liger-kernel-nightly 0.5.2.dev20241223032630__tar.gz → 0.5.2.dev20241228022953__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (204) hide show
  1. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/Makefile +3 -4
  2. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/PKG-INFO +1 -1
  3. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/benchmarks_visualizer.py +5 -10
  4. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_cpo_loss.py +11 -16
  5. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_cross_entropy.py +10 -13
  6. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_dpo_loss.py +16 -30
  7. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_embedding.py +10 -13
  8. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +16 -29
  9. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_fused_linear_jsd.py +30 -43
  10. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_geglu.py +7 -8
  11. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_group_norm.py +15 -28
  12. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_jsd.py +13 -20
  13. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_kl_div.py +10 -17
  14. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_layer_norm.py +11 -16
  15. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_orpo_loss.py +11 -16
  16. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_qwen2vl_mrope.py +23 -48
  17. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_rms_norm.py +9 -10
  18. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_rope.py +21 -42
  19. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_simpo_loss.py +11 -16
  20. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/benchmark_swiglu.py +8 -11
  21. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/utils.py +16 -25
  22. liger_kernel_nightly-0.5.2.dev20241228022953/dev/fmt-requirements.txt +1 -0
  23. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/dev/modal/tests.py +1 -3
  24. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/dev/modal/tests_bwd.py +1 -3
  25. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/alignment/run_orpo.py +4 -4
  26. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/callback.py +16 -34
  27. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/launch_on_modal.py +2 -5
  28. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/training.py +5 -7
  29. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/training_multimodal.py +5 -7
  30. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/lightning/training.py +17 -31
  31. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/callback.py +19 -43
  32. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/medusa_util.py +11 -33
  33. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/train.py +19 -39
  34. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/pyproject.toml +34 -2
  35. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/setup.py +1 -0
  36. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/chunked_loss/cpo_loss.py +5 -12
  37. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/chunked_loss/dpo_loss.py +1 -4
  38. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  39. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  40. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/chunked_loss/orpo_loss.py +2 -6
  41. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/chunked_loss/simpo_loss.py +4 -8
  42. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/env_report.py +4 -11
  43. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/cross_entropy.py +7 -10
  44. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/experimental/embedding.py +1 -3
  45. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  46. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/fused_linear_cross_entropy.py +12 -17
  47. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/fused_linear_jsd.py +11 -29
  48. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/geglu.py +6 -17
  49. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/group_norm.py +11 -28
  50. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/jsd.py +2 -6
  51. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/kl_div.py +4 -7
  52. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/layer_norm.py +3 -5
  53. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/qwen2vl_mrope.py +8 -25
  54. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/rms_norm.py +11 -29
  55. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/rope.py +8 -24
  56. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/swiglu.py +4 -8
  57. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/utils.py +2 -0
  58. liger_kernel_nightly-0.5.2.dev20241228022953/src/liger_kernel/transformers/__init__.py +23 -0
  59. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/auto_model.py +6 -13
  60. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/cross_entropy.py +1 -3
  61. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/experimental/embedding.py +1 -3
  62. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/functional.py +2 -6
  63. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  64. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/geglu.py +1 -4
  65. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/group_norm.py +3 -9
  66. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/jsd.py +1 -3
  67. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/kl_div.py +1 -3
  68. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/layer_norm.py +3 -9
  69. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/model/gemma.py +18 -40
  70. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/model/gemma2.py +19 -41
  71. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/model/llama.py +22 -48
  72. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/model/mistral.py +14 -26
  73. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/model/mixtral.py +23 -53
  74. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/model/mllama.py +16 -36
  75. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/model/phi3.py +18 -40
  76. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/model/qwen2.py +18 -40
  77. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/model/qwen2_vl.py +16 -30
  78. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/monkey_patch.py +43 -117
  79. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/rms_norm.py +4 -4
  80. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/swiglu.py +2 -8
  81. liger_kernel_nightly-0.5.2.dev20241228022953/src/liger_kernel/transformers/trainer/__init__.py +4 -0
  82. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  83. liger_kernel_nightly-0.5.2.dev20241228022953/src/liger_kernel/triton/__init__.py +1 -0
  84. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/triton/monkey_patch.py +1 -3
  85. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  86. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -2
  87. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/chunked_loss/test_cpo_loss.py +10 -22
  88. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/chunked_loss/test_dpo_loss.py +15 -30
  89. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/chunked_loss/test_orpo_loss.py +12 -23
  90. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/chunked_loss/test_simpo_loss.py +11 -19
  91. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/convergence/test_mini_models.py +54 -76
  92. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/convergence/test_mini_models_multimodal.py +35 -73
  93. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/convergence/test_mini_models_with_logits.py +54 -76
  94. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/resources/scripts/generate_tokenized_dataset.py +4 -12
  95. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_auto_model.py +14 -17
  96. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_cross_entropy.py +32 -90
  97. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_embedding.py +6 -19
  98. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_fused_linear_cross_entropy.py +16 -28
  99. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_fused_linear_jsd.py +41 -66
  100. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_geglu.py +3 -5
  101. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_group_norm.py +5 -17
  102. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_jsd.py +25 -46
  103. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_kl_div.py +5 -11
  104. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_layer_norm.py +2 -6
  105. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_mm_int8int2.py +8 -20
  106. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_monkey_patch.py +153 -331
  107. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_qwen2vl_mrope.py +15 -43
  108. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_rms_norm.py +9 -18
  109. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_rope.py +9 -25
  110. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_swiglu.py +8 -15
  111. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/transformers/test_trainer_integration.py +1 -3
  112. liger_kernel_nightly-0.5.2.dev20241228022953/test/transformers/test_transformers.py +16 -0
  113. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/utils.py +28 -55
  114. liger_kernel_nightly-0.5.2.dev20241223032630/.flake8 +0 -10
  115. liger_kernel_nightly-0.5.2.dev20241223032630/.isort.cfg +0 -2
  116. liger_kernel_nightly-0.5.2.dev20241223032630/dev/fmt-requirements.txt +0 -3
  117. liger_kernel_nightly-0.5.2.dev20241223032630/src/liger_kernel/transformers/__init__.py +0 -31
  118. liger_kernel_nightly-0.5.2.dev20241223032630/src/liger_kernel/transformers/trainer/__init__.py +0 -6
  119. liger_kernel_nightly-0.5.2.dev20241223032630/src/liger_kernel/triton/__init__.py +0 -3
  120. liger_kernel_nightly-0.5.2.dev20241223032630/test/transformers/test_transformers.py +0 -18
  121. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  122. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  123. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/.github/pull_request_template.md +0 -0
  124. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/.github/workflows/amd-ci.yml +0 -0
  125. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/.github/workflows/nvi-ci.yml +0 -0
  126. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/.github/workflows/publish-nightly.yml +0 -0
  127. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/.github/workflows/publish-release.yml +0 -0
  128. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/.gitignore +0 -0
  129. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/LICENSE +0 -0
  130. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/NOTICE +0 -0
  131. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/README.md +0 -0
  132. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/__init__.py +0 -0
  133. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/data/all_benchmark_data.csv +0 -0
  134. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/benchmark/scripts/__init__.py +0 -0
  135. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/docs/Acknowledgement.md +0 -0
  136. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/docs/CONTRIBUTING.md +0 -0
  137. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/docs/License.md +0 -0
  138. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/docs/images/banner.GIF +0 -0
  139. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/docs/images/compose.gif +0 -0
  140. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/docs/images/e2e-memory.png +0 -0
  141. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/docs/images/e2e-tps.png +0 -0
  142. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/docs/images/logo-banner.png +0 -0
  143. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/docs/images/patch.gif +0 -0
  144. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/docs/images/post-training.png +0 -0
  145. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/alignment/accelerate_config.yaml +0 -0
  146. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/README.md +0 -0
  147. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/config/fsdp_config.json +0 -0
  148. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  149. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  150. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  151. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/img/llama_tps.png +0 -0
  152. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  153. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/img/qwen_tps.png +0 -0
  154. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/requirements.txt +0 -0
  155. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/run_benchmarks.sh +0 -0
  156. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/run_gemma.sh +0 -0
  157. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/run_llama.sh +0 -0
  158. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/run_qwen.sh +0 -0
  159. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/huggingface/run_qwen2_vl.sh +0 -0
  160. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/lightning/README.md +0 -0
  161. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/lightning/requirements.txt +0 -0
  162. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/README.md +0 -0
  163. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  164. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  165. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  166. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  167. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  168. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  169. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  170. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  171. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  172. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/requirements.txt +0 -0
  173. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  174. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/licenses/LICENSE-Apache-2.0 +0 -0
  175. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  176. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  177. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/licenses/LICENSE-MIT-llmc +0 -0
  178. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/licenses/LICENSE-MIT-triton +0 -0
  179. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/setup.cfg +0 -0
  180. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/__init__.py +0 -0
  181. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/chunked_loss/README.md +0 -0
  182. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  183. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/chunked_loss/functional.py +0 -0
  184. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/ops/__init__.py +0 -0
  185. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  186. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/model/__init__.py +0 -0
  187. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  188. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/rope.py +0 -0
  189. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  190. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel/utils.py +0 -0
  191. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  192. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  193. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  194. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/__init__.py +0 -0
  195. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/chunked_loss/__init__.py +0 -0
  196. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/conftest.py +0 -0
  197. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/convergence/__init__.py +0 -0
  198. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  199. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  200. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/resources/tiny_shakespeare.txt +0 -0
  201. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  202. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  203. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  204. {liger_kernel_nightly-0.5.2.dev20241223032630 → liger_kernel_nightly-0.5.2.dev20241228022953}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -10,10 +10,9 @@ test:
10
10
  # Command to run flake8 (code style check), isort (import ordering), and black (code formatting)
11
11
  # Subsequent commands still run if the previous fails, but return failure at the end
12
12
  checkstyle:
13
- flake8 .; flake8_status=$$?; \
14
- isort .; isort_status=$$?; \
15
- black .; black_status=$$?; \
16
- if [ $$flake8_status -ne 0 ] || [ $$isort_status -ne 0 ] || [ $$black_status -ne 0 ]; then \
13
+ ruff check . --fix; ruff_check_status=$$?; \
14
+ ruff format .; ruff_format_status=$$?; \
15
+ if [ $$ruff_check_status -ne 0 ] || [ $$ruff_format_status -ne 0 ]; then \
17
16
  exit 1; \
18
17
  fi
19
18
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241223032630
3
+ Version: 0.5.2.dev20241228022953
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -1,5 +1,6 @@
1
1
  import json
2
2
  import os
3
+
3
4
  from argparse import ArgumentParser
4
5
  from dataclasses import dataclass
5
6
 
@@ -39,9 +40,7 @@ def parse_args() -> VisualizationsConfig:
39
40
  VisualizationsConfig: Configuration object for the visualizations script.
40
41
  """
41
42
  parser = ArgumentParser()
42
- parser.add_argument(
43
- "--kernel-name", type=str, required=True, help="Kernel name to benchmark"
44
- )
43
+ parser.add_argument("--kernel-name", type=str, required=True, help="Kernel name to benchmark")
45
44
  parser.add_argument(
46
45
  "--metric-name",
47
46
  type=str,
@@ -54,9 +53,7 @@ def parse_args() -> VisualizationsConfig:
54
53
  required=True,
55
54
  help="Kernel operation mode to visualize (forward/backward/full)",
56
55
  )
57
- parser.add_argument(
58
- "--display", action="store_true", help="Display the visualization"
59
- )
56
+ parser.add_argument("--display", action="store_true", help="Display the visualization")
60
57
  parser.add_argument(
61
58
  "--overwrite",
62
59
  action="store_true",
@@ -126,7 +123,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
126
123
  lines = ax.get_lines()
127
124
  colors = [line.get_color() for line in lines]
128
125
 
129
- for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
126
+ for (_, group_data), color in zip(df.groupby("kernel_provider"), colors, strict=False):
130
127
  # for i, row in group_data.iterrows():
131
128
  y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
132
129
  y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
@@ -145,9 +142,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
145
142
  plt.ylabel(ylabel)
146
143
  plt.tight_layout()
147
144
 
148
- out_path = os.path.join(
149
- VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png"
150
- )
145
+ out_path = os.path.join(VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png")
151
146
 
152
147
  if config.display:
153
148
  plt.show()
@@ -3,14 +3,13 @@ import sys
3
3
 
4
4
  import torch
5
5
  import triton
6
- from utils import (
7
- QUANTILES,
8
- SingleBenchmarkRunInput,
9
- SingleBenchmarkRunOutput,
10
- _test_memory,
11
- parse_benchmark_script_args,
12
- run_benchmarks,
13
- )
6
+
7
+ from utils import QUANTILES
8
+ from utils import SingleBenchmarkRunInput
9
+ from utils import SingleBenchmarkRunOutput
10
+ from utils import _test_memory
11
+ from utils import parse_benchmark_script_args
12
+ from utils import run_benchmarks
14
13
 
15
14
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
16
15
  from liger_kernel.utils import infer_device
@@ -33,9 +32,7 @@ class TorchLMHeadCPO(torch.nn.Module):
33
32
  from test.chunked_loss.test_cpo_loss import HFCPOLoss
34
33
 
35
34
  super().__init__()
36
- self.lin = torch.nn.Linear(
37
- in_features=H, out_features=V, bias=False, dtype=dtype
38
- )
35
+ self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
39
36
  self.cpo_loss = HFCPOLoss().get_batch_loss_metrics
40
37
 
41
38
  def forward(self, x, y):
@@ -45,9 +42,7 @@ class TorchLMHeadCPO(torch.nn.Module):
45
42
  class LigerLMHeadCPO(torch.nn.Module):
46
43
  def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
47
44
  super().__init__()
48
- self.lin = torch.nn.Linear(
49
- in_features=H, out_features=V, bias=False, dtype=dtype
50
- )
45
+ self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
51
46
  self.cpo_loss = LigerFusedLinearCPOFunction.apply
52
47
 
53
48
  def forward(self, x, y):
@@ -180,12 +175,12 @@ if __name__ == "__main__":
180
175
  kernel_operation_modes=["forward", "full"],
181
176
  metric_name="speed",
182
177
  metric_unit="ms",
183
- **common_configs
178
+ **common_configs,
184
179
  )
185
180
  run_benchmarks(
186
181
  bench_test_fn=bench_memory_fused_linear_cpo_loss,
187
182
  kernel_operation_modes=["full"],
188
183
  metric_name="memory",
189
184
  metric_unit="MB",
190
- **common_configs
185
+ **common_configs,
191
186
  )
@@ -1,14 +1,13 @@
1
1
  import torch
2
2
  import triton
3
+
3
4
  from torch.nn import CrossEntropyLoss
4
- from utils import (
5
- QUANTILES,
6
- SingleBenchmarkRunInput,
7
- SingleBenchmarkRunOutput,
8
- _test_memory,
9
- parse_benchmark_script_args,
10
- run_benchmarks,
11
- )
5
+ from utils import QUANTILES
6
+ from utils import SingleBenchmarkRunInput
7
+ from utils import SingleBenchmarkRunOutput
8
+ from utils import _test_memory
9
+ from utils import parse_benchmark_script_args
10
+ from utils import run_benchmarks
12
11
 
13
12
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
14
13
  from liger_kernel.utils import infer_device
@@ -86,9 +85,7 @@ def bench_speed_cross_entropy(
86
85
  y = fwd()
87
86
  y.backward()
88
87
 
89
- ms_50, ms_20, ms_80 = triton.testing.do_bench(
90
- full, rep=100, quantiles=QUANTILES
91
- )
88
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
92
89
 
93
90
  return SingleBenchmarkRunOutput(
94
91
  y_20=ms_20,
@@ -115,12 +112,12 @@ if __name__ == "__main__":
115
112
  kernel_operation_modes=["forward", "full"],
116
113
  metric_name="speed",
117
114
  metric_unit="ms",
118
- **common_configs
115
+ **common_configs,
119
116
  )
120
117
  run_benchmarks(
121
118
  bench_test_fn=bench_memory_cross_entropy,
122
119
  kernel_operation_modes=["full"],
123
120
  metric_name="memory",
124
121
  metric_unit="MB",
125
- **common_configs
122
+ **common_configs,
126
123
  )
@@ -1,15 +1,13 @@
1
- from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
2
-
3
1
  import torch
4
2
  import triton
5
- from utils import (
6
- QUANTILES,
7
- SingleBenchmarkRunInput,
8
- SingleBenchmarkRunOutput,
9
- _test_memory,
10
- parse_benchmark_script_args,
11
- run_benchmarks,
12
- )
3
+
4
+ from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
5
+ from utils import QUANTILES
6
+ from utils import SingleBenchmarkRunInput
7
+ from utils import SingleBenchmarkRunOutput
8
+ from utils import _test_memory
9
+ from utils import parse_benchmark_script_args
10
+ from utils import run_benchmarks
13
11
 
14
12
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
15
13
  from liger_kernel.utils import infer_device
@@ -28,9 +26,7 @@ class TorchDPOLoss(torch.nn.Module):
28
26
  bias: bool = False,
29
27
  ):
30
28
  super().__init__()
31
- self.lin = torch.nn.Linear(
32
- in_features=H, out_features=V, bias=bias, dtype=dtype
33
- )
29
+ self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
34
30
  self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index)
35
31
 
36
32
  def forward(self, x, target):
@@ -53,9 +49,7 @@ class LigerDPOLoss(torch.nn.Module):
53
49
  bias: bool = False,
54
50
  ):
55
51
  super().__init__()
56
- self.lin = torch.nn.Linear(
57
- in_features=H, out_features=V, bias=bias, dtype=dtype
58
- )
52
+ self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
59
53
  self.beta = beta
60
54
  self.ignore_index = ignore_index
61
55
 
@@ -82,12 +76,8 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
82
76
  ignore_index = input.extra_benchmark_config["ignore_index"]
83
77
  provider = input.kernel_provider
84
78
 
85
- torch_dpo_loss = TorchDPOLoss(
86
- H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
87
- ).to(device)
88
- liger_dpo_loss = LigerDPOLoss(
89
- H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
90
- ).to(device)
79
+ torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
80
+ liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
91
81
 
92
82
  # Input shape: [B, T, H]
93
83
  _input = torch.randn(B, T, H, device=device, dtype=dtype)
@@ -129,12 +119,8 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
129
119
  provider = input.kernel_provider
130
120
  mode = input.kernel_operation_mode
131
121
 
132
- torch_dpo_loss = TorchDPOLoss(
133
- H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
134
- ).to(device)
135
- liger_dpo_loss = LigerDPOLoss(
136
- H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
137
- ).to(device)
122
+ torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
123
+ liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
138
124
 
139
125
  # Input shape: [B, T, H]
140
126
  _input = torch.randn(B, T, H, device=device, dtype=dtype)
@@ -215,7 +201,7 @@ if __name__ == "__main__":
215
201
  kernel_operation_modes=["forward", "full"],
216
202
  metric_name="speed",
217
203
  metric_unit="ms",
218
- **common_configs
204
+ **common_configs,
219
205
  )
220
206
 
221
207
  run_benchmarks(
@@ -223,5 +209,5 @@ if __name__ == "__main__":
223
209
  kernel_operation_modes=["full"],
224
210
  metric_name="memory",
225
211
  metric_unit="MB",
226
- **common_configs
212
+ **common_configs,
227
213
  )
@@ -1,14 +1,13 @@
1
1
  import torch
2
2
  import triton
3
+
3
4
  from torch.nn import Embedding
4
- from utils import (
5
- QUANTILES,
6
- SingleBenchmarkRunInput,
7
- SingleBenchmarkRunOutput,
8
- _test_memory,
9
- parse_benchmark_script_args,
10
- run_benchmarks,
11
- )
5
+ from utils import QUANTILES
6
+ from utils import SingleBenchmarkRunInput
7
+ from utils import SingleBenchmarkRunOutput
8
+ from utils import _test_memory
9
+ from utils import parse_benchmark_script_args
10
+ from utils import run_benchmarks
12
11
 
13
12
  from liger_kernel.transformers.experimental.embedding import LigerEmbedding
14
13
  from liger_kernel.utils import infer_device
@@ -50,9 +49,7 @@ def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
50
49
  if mode == "forward":
51
50
  ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
52
51
  elif mode == "full":
53
- ms_50, ms_20, ms_80 = triton.testing.do_bench(
54
- full, quantiles=QUANTILES, rep=100
55
- )
52
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
56
53
  return SingleBenchmarkRunOutput(
57
54
  y_20=ms_20,
58
55
  y_50=ms_50,
@@ -118,12 +115,12 @@ if __name__ == "__main__":
118
115
  kernel_operation_modes=["forward", "full"],
119
116
  metric_name="speed",
120
117
  metric_unit="ms",
121
- **common_configs
118
+ **common_configs,
122
119
  )
123
120
  run_benchmarks(
124
121
  bench_test_fn=bench_memory_embedding,
125
122
  kernel_operation_modes=["full"],
126
123
  metric_name="memory",
127
124
  metric_unit="MB",
128
- **common_configs
125
+ **common_configs,
129
126
  )
@@ -1,17 +1,14 @@
1
1
  import torch
2
2
  import triton
3
- from utils import (
4
- QUANTILES,
5
- SingleBenchmarkRunInput,
6
- SingleBenchmarkRunOutput,
7
- _test_memory,
8
- parse_benchmark_script_args,
9
- run_benchmarks,
10
- )
11
-
12
- from liger_kernel.transformers.fused_linear_cross_entropy import (
13
- LigerFusedLinearCrossEntropyLoss,
14
- )
3
+
4
+ from utils import QUANTILES
5
+ from utils import SingleBenchmarkRunInput
6
+ from utils import SingleBenchmarkRunOutput
7
+ from utils import _test_memory
8
+ from utils import parse_benchmark_script_args
9
+ from utils import run_benchmarks
10
+
11
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
15
12
  from liger_kernel.utils import infer_device
16
13
 
17
14
  device = infer_device()
@@ -28,12 +25,8 @@ class TorchLMHeadCE(torch.nn.Module):
28
25
 
29
26
  def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
30
27
  super().__init__()
31
- self.lin = torch.nn.Linear(
32
- in_features=H, out_features=V, bias=False, dtype=dtype
33
- )
34
- self.ce_loss = torch.nn.CrossEntropyLoss(
35
- ignore_index=ignore_index, reduction="mean"
36
- )
28
+ self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
29
+ self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction="mean")
37
30
 
38
31
  def forward(self, x, y):
39
32
  logits = self.lin(x)
@@ -43,12 +36,8 @@ class TorchLMHeadCE(torch.nn.Module):
43
36
  class LigerLMHeadCE(torch.nn.Module):
44
37
  def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
45
38
  super().__init__()
46
- self.lin = torch.nn.Linear(
47
- in_features=H, out_features=V, bias=False, dtype=dtype
48
- )
49
- self.ce_loss = LigerFusedLinearCrossEntropyLoss(
50
- ignore_index=ignore_index, reduction="mean"
51
- )
39
+ self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
40
+ self.ce_loss = LigerFusedLinearCrossEntropyLoss(ignore_index=ignore_index, reduction="mean")
52
41
 
53
42
  def forward(self, x, y):
54
43
  return self.ce_loss(self.lin.weight, x, y)
@@ -161,9 +150,7 @@ if __name__ == "__main__":
161
150
  "x_label": "B x T",
162
151
  "x_values": [2**i for i in range(12, 16)],
163
152
  "kernel_providers": ["liger", "huggingface"],
164
- "extra_benchmark_configs": [
165
- {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}
166
- ],
153
+ "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}],
167
154
  "overwrite": args.overwrite,
168
155
  }
169
156
 
@@ -172,12 +159,12 @@ if __name__ == "__main__":
172
159
  kernel_operation_modes=["forward", "full"],
173
160
  metric_name="speed",
174
161
  metric_unit="ms",
175
- **common_configs
162
+ **common_configs,
176
163
  )
177
164
  run_benchmarks(
178
165
  bench_test_fn=bench_memory_fused_linear_cross_entropy,
179
166
  kernel_operation_modes=["full"],
180
167
  metric_name="memory",
181
168
  metric_unit="MB",
182
- **common_configs
169
+ **common_configs,
183
170
  )
@@ -1,13 +1,12 @@
1
1
  import torch
2
2
  import triton
3
- from utils import (
4
- QUANTILES,
5
- SingleBenchmarkRunInput,
6
- SingleBenchmarkRunOutput,
7
- _test_memory,
8
- parse_benchmark_script_args,
9
- run_benchmarks,
10
- )
3
+
4
+ from utils import QUANTILES
5
+ from utils import SingleBenchmarkRunInput
6
+ from utils import SingleBenchmarkRunOutput
7
+ from utils import _test_memory
8
+ from utils import parse_benchmark_script_args
9
+ from utils import run_benchmarks
11
10
 
12
11
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD
13
12
  from liger_kernel.utils import infer_device
@@ -37,9 +36,9 @@ class TorchJSD(torch.nn.Module):
37
36
  log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
38
37
  log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
39
38
  m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
40
- loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
41
- 1 - self.beta
42
- ) * self.kl(torch.log(m), log_q).sum(dim=-1)
39
+ loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl(
40
+ torch.log(m), log_q
41
+ ).sum(dim=-1)
43
42
 
44
43
  if label is not None:
45
44
  loss = torch.where(label != self.ignore_index, loss, 0.0)
@@ -73,12 +72,8 @@ class TorchLMHeadJSD(torch.nn.Module):
73
72
  temperature: float = 1.0,
74
73
  ):
75
74
  super().__init__()
76
- self.student_lin = torch.nn.Linear(
77
- in_features=H, out_features=V, bias=False, dtype=dtype, device=device
78
- )
79
- self.teacher_lin = torch.nn.Linear(
80
- in_features=H, out_features=V, bias=False, dtype=dtype, device=device
81
- )
75
+ self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
76
+ self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
82
77
  self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype)
83
78
  self.temperature = temperature
84
79
 
@@ -103,15 +98,9 @@ class LigerLMHeadJSD(torch.nn.Module):
103
98
  temperature: float = 1.0,
104
99
  ):
105
100
  super().__init__()
106
- self.student_lin = torch.nn.Linear(
107
- in_features=H, out_features=V, bias=False, dtype=dtype, device=device
108
- )
109
- self.teacher_lin = torch.nn.Linear(
110
- in_features=H, out_features=V, bias=False, dtype=dtype, device=device
111
- )
112
- self.fused_jsd = LigerFusedLinearJSD(
113
- jsd_beta=beta, ignore_index=ignore_index, temperature=temperature
114
- )
101
+ self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
102
+ self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
103
+ self.fused_jsd = LigerFusedLinearJSD(jsd_beta=beta, ignore_index=ignore_index, temperature=temperature)
115
104
 
116
105
  def forward(self, student_input, teacher_input, label=None):
117
106
  return self.fused_jsd(
@@ -141,12 +130,12 @@ def bench_memory_fused_linear_jsd(
141
130
  liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
142
131
 
143
132
  # init the linear in all FusedLinearJSDs with the same weights
144
- torch_lm_head_jsd.student_lin.weight.data = (
145
- liger_lm_head_jsd.student_lin.weight.data
146
- ) = torch.rand(V, H, device=device, dtype=dtype)
147
- torch_lm_head_jsd.teacher_lin.weight.data = (
148
- liger_lm_head_jsd.teacher_lin.weight.data
149
- ) = torch.rand(V, H, device=device, dtype=dtype)
133
+ torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
134
+ V, H, device=device, dtype=dtype
135
+ )
136
+ torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand(
137
+ V, H, device=device, dtype=dtype
138
+ )
150
139
 
151
140
  student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device)
152
141
  teacher_input = torch.rand(BT, H, dtype=dtype, device=device)
@@ -189,12 +178,12 @@ def bench_speed_fused_linear_jsd(
189
178
  liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
190
179
 
191
180
  # init the linear in all FusedLinearJSDs with the same weights
192
- torch_lm_head_jsd.student_lin.weight.data = (
193
- liger_lm_head_jsd.student_lin.weight.data
194
- ) = torch.rand(V, H, device=device, dtype=dtype)
195
- torch_lm_head_jsd.teacher_lin.weight.data = (
196
- liger_lm_head_jsd.teacher_lin.weight.data
197
- ) = torch.rand(V, H, device=device, dtype=dtype)
181
+ torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
182
+ V, H, device=device, dtype=dtype
183
+ )
184
+ torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand(
185
+ V, H, device=device, dtype=dtype
186
+ )
198
187
 
199
188
  student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device)
200
189
  teacher_input = torch.rand(BT, H, dtype=dtype, device=device)
@@ -251,9 +240,7 @@ if __name__ == "__main__":
251
240
  "x_label": "B x T",
252
241
  "x_values": [2**i for i in range(10, 14)],
253
242
  "kernel_providers": ["liger", "torch"],
254
- "extra_benchmark_configs": [
255
- {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}
256
- ],
243
+ "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}],
257
244
  "overwrite": args.overwrite,
258
245
  }
259
246
 
@@ -262,12 +249,12 @@ if __name__ == "__main__":
262
249
  kernel_operation_modes=["forward", "full"],
263
250
  metric_name="speed",
264
251
  metric_unit="ms",
265
- **common_configs
252
+ **common_configs,
266
253
  )
267
254
  run_benchmarks(
268
255
  bench_test_fn=bench_memory_fused_linear_jsd,
269
256
  kernel_operation_modes=["full"],
270
257
  metric_name="memory",
271
258
  metric_unit="MB",
272
- **common_configs
259
+ **common_configs,
273
260
  )
@@ -1,15 +1,14 @@
1
1
  import torch
2
2
  import triton
3
+
3
4
  from transformers.models.llama.configuration_llama import LlamaConfig
4
5
  from transformers.models.llama.modeling_llama import LlamaMLP
5
- from utils import (
6
- QUANTILES,
7
- SingleBenchmarkRunInput,
8
- SingleBenchmarkRunOutput,
9
- _test_memory,
10
- parse_benchmark_script_args,
11
- run_benchmarks,
12
- )
6
+ from utils import QUANTILES
7
+ from utils import SingleBenchmarkRunInput
8
+ from utils import SingleBenchmarkRunOutput
9
+ from utils import _test_memory
10
+ from utils import parse_benchmark_script_args
11
+ from utils import run_benchmarks
13
12
 
14
13
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
15
14
  from liger_kernel.utils import infer_device
@@ -1,13 +1,12 @@
1
1
  import torch
2
2
  import triton
3
- from utils import (
4
- QUANTILES,
5
- SingleBenchmarkRunInput,
6
- SingleBenchmarkRunOutput,
7
- _test_memory,
8
- parse_benchmark_script_args,
9
- run_benchmarks,
10
- )
3
+
4
+ from utils import QUANTILES
5
+ from utils import SingleBenchmarkRunInput
6
+ from utils import SingleBenchmarkRunOutput
7
+ from utils import _test_memory
8
+ from utils import parse_benchmark_script_args
9
+ from utils import run_benchmarks
11
10
 
12
11
  from liger_kernel.transformers.group_norm import LigerGroupNorm
13
12
  from liger_kernel.utils import infer_device
@@ -27,12 +26,8 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
27
26
  dtype = extra_benchmark_config["dtype"]
28
27
 
29
28
  x_shape = (M, C, H)
30
- triton_ln = LigerGroupNorm(
31
- num_channels=C, num_groups=C // channels_per_group, eps=eps
32
- ).to(device)
33
- torch_ln = torch.nn.GroupNorm(
34
- num_groups=C // channels_per_group, num_channels=C, eps=eps
35
- ).to(device)
29
+ triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device)
30
+ torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device)
36
31
 
37
32
  x = torch.randn(x_shape, dtype=dtype, device=device)
38
33
  dy = torch.randn_like(x)
@@ -45,9 +40,7 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
45
40
  return torch_ln(x)
46
41
 
47
42
  if mode == "forward":
48
- ms_50, ms_20, ms_80 = triton.testing.do_bench(
49
- y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
50
- )
43
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
51
44
  elif mode == "backward":
52
45
  y = y_fwd()
53
46
  ms_50, ms_20, ms_80 = triton.testing.do_bench(
@@ -62,9 +55,7 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
62
55
  y = y_fwd()
63
56
  y.backward(dy, retain_graph=True)
64
57
 
65
- ms_50, ms_20, ms_80 = triton.testing.do_bench(
66
- full, quantiles=QUANTILES, grad_to_none=[x], rep=500
67
- )
58
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
68
59
 
69
60
  return SingleBenchmarkRunOutput(
70
61
  y_20=ms_20,
@@ -84,12 +75,8 @@ def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu
84
75
  dtype = extra_benchmark_config["dtype"]
85
76
 
86
77
  x_shape = (M, C, H)
87
- triton_ln = LigerGroupNorm(
88
- num_channels=C, num_groups=C // channels_per_group, eps=eps
89
- ).to(device)
90
- torch_ln = torch.nn.GroupNorm(
91
- num_groups=C // channels_per_group, num_channels=C, eps=eps
92
- ).to(device)
78
+ triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device)
79
+ torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device)
93
80
 
94
81
  x = torch.randn(x_shape, dtype=dtype, device=device)
95
82
  dy = torch.randn_like(x)
@@ -139,12 +126,12 @@ if __name__ == "__main__":
139
126
  kernel_operation_modes=["forward", "full", "backward"],
140
127
  metric_name="speed",
141
128
  metric_unit="ms",
142
- **common_configs
129
+ **common_configs,
143
130
  )
144
131
  run_benchmarks(
145
132
  bench_test_fn=bench_memory_group_norm,
146
133
  kernel_operation_modes=["full", "forward", "backward"],
147
134
  metric_name="memory",
148
135
  metric_unit="MB",
149
- **common_configs
136
+ **common_configs,
150
137
  )