liger-kernel-nightly 0.5.2.dev20250122005057__tar.gz → 0.5.2.dev20250124002122__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (209) hide show
  1. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/.gitignore +1 -0
  2. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/PKG-INFO +1 -1
  3. liger_kernel_nightly-0.5.2.dev20250124002122/benchmark/README.md +30 -0
  4. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/data/all_benchmark_data.csv +30 -0
  5. liger_kernel_nightly-0.5.2.dev20250124002122/benchmark/scripts/benchmark_kto_loss.py +314 -0
  6. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/pyproject.toml +1 -1
  7. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/chunked_loss/README.md +1 -1
  8. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/chunked_loss/__init__.py +1 -0
  9. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/chunked_loss/functional.py +2 -0
  10. liger_kernel_nightly-0.5.2.dev20250124002122/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  11. liger_kernel_nightly-0.5.2.dev20250124002122/src/liger_kernel/chunked_loss/kto_loss.py +172 -0
  12. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  13. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel_nightly.egg-info/SOURCES.txt +5 -0
  14. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/chunked_loss/test_dpo_loss.py +2 -2
  15. liger_kernel_nightly-0.5.2.dev20250124002122/test/chunked_loss/test_kto_loss.py +353 -0
  16. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/triton/test_triton_monkey_patch.py +8 -1
  17. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/utils.py +59 -27
  18. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  19. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  20. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/.github/pull_request_template.md +0 -0
  21. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/.github/workflows/amd-ci.yml +0 -0
  22. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/.github/workflows/docs.yml +0 -0
  23. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/.github/workflows/nvi-ci.yml +0 -0
  24. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/.github/workflows/publish-nightly.yml +0 -0
  25. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/.github/workflows/publish-release.yml +0 -0
  26. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/LICENSE +0 -0
  27. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/Makefile +0 -0
  28. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/NOTICE +0 -0
  29. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/README.md +0 -0
  30. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/__init__.py +0 -0
  31. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/benchmarks_visualizer.py +0 -0
  32. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/__init__.py +0 -0
  33. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  34. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  35. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  36. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_embedding.py +0 -0
  37. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  38. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  39. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_geglu.py +0 -0
  40. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_group_norm.py +0 -0
  41. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_jsd.py +0 -0
  42. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_kl_div.py +0 -0
  43. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  44. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  45. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  46. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  47. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_rope.py +0 -0
  48. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  49. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/benchmark_swiglu.py +0 -0
  50. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/benchmark/scripts/utils.py +0 -0
  51. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/dev/fmt-requirements.txt +0 -0
  52. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/dev/modal/tests.py +0 -0
  53. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/dev/modal/tests_bwd.py +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/Examples.md +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/Getting-Started.md +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/High-Level-APIs.md +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/Low-Level-APIs.md +0 -0
  58. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/acknowledgement.md +0 -0
  59. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/contributing.md +0 -0
  60. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/images/banner.GIF +0 -0
  61. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/images/compose.gif +0 -0
  62. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/images/e2e-memory.png +0 -0
  63. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/images/e2e-tps.png +0 -0
  64. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/images/logo-banner.png +0 -0
  65. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/images/patch.gif +0 -0
  66. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/images/post-training.png +0 -0
  67. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/index.md +0 -0
  68. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/docs/license.md +0 -0
  69. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/alignment/accelerate_config.yaml +0 -0
  70. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/alignment/run_orpo.py +0 -0
  71. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/README.md +0 -0
  72. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/callback.py +0 -0
  73. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/config/fsdp_config.json +0 -0
  74. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  75. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  76. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  77. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/img/llama_tps.png +0 -0
  78. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  79. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/img/qwen_tps.png +0 -0
  80. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/launch_on_modal.py +0 -0
  81. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/requirements.txt +0 -0
  82. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/run_benchmarks.sh +0 -0
  83. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/run_gemma.sh +0 -0
  84. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/run_llama.sh +0 -0
  85. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/run_qwen.sh +0 -0
  86. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/run_qwen2_vl.sh +0 -0
  87. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/training.py +0 -0
  88. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/huggingface/training_multimodal.py +0 -0
  89. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/lightning/README.md +0 -0
  90. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/lightning/requirements.txt +0 -0
  91. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/lightning/training.py +0 -0
  92. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/README.md +0 -0
  93. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/callback.py +0 -0
  94. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  95. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  96. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  97. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  98. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  99. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  100. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  101. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  102. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  103. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/medusa_util.py +0 -0
  104. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/requirements.txt +0 -0
  105. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  106. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/examples/medusa/train.py +0 -0
  107. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/licenses/LICENSE-Apache-2.0 +0 -0
  108. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  109. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  110. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/licenses/LICENSE-MIT-llmc +0 -0
  111. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/licenses/LICENSE-MIT-triton +0 -0
  112. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/mkdocs.yml +0 -0
  113. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/setup.cfg +0 -0
  114. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/setup.py +0 -0
  115. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/__init__.py +0 -0
  116. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  117. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  118. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  119. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  120. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  121. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  122. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/env_report.py +0 -0
  123. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/__init__.py +0 -0
  124. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/cross_entropy.py +0 -0
  125. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  126. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  127. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  128. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  129. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/geglu.py +0 -0
  130. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/group_norm.py +0 -0
  131. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/jsd.py +0 -0
  132. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/kl_div.py +0 -0
  133. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/layer_norm.py +0 -0
  134. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  135. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/rms_norm.py +0 -0
  136. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/rope.py +0 -0
  137. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/swiglu.py +0 -0
  138. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/ops/utils.py +0 -0
  139. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/__init__.py +0 -0
  140. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/auto_model.py +0 -0
  141. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  142. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  143. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/functional.py +0 -0
  144. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  145. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  146. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/geglu.py +0 -0
  147. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/group_norm.py +0 -0
  148. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/jsd.py +0 -0
  149. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/kl_div.py +0 -0
  150. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/layer_norm.py +0 -0
  151. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/model/__init__.py +0 -0
  152. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/model/gemma.py +0 -0
  153. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  154. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/model/llama.py +0 -0
  155. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/model/mistral.py +0 -0
  156. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  157. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/model/mllama.py +0 -0
  158. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/model/phi3.py +0 -0
  159. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  160. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  161. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  162. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  163. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/rms_norm.py +0 -0
  164. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/rope.py +0 -0
  165. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/swiglu.py +0 -0
  166. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  167. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  168. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  169. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/triton/__init__.py +0 -0
  170. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/triton/monkey_patch.py +0 -0
  171. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel/utils.py +0 -0
  172. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  173. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  174. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  175. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/__init__.py +0 -0
  176. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/chunked_loss/__init__.py +0 -0
  177. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/chunked_loss/test_cpo_loss.py +0 -0
  178. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/chunked_loss/test_orpo_loss.py +0 -0
  179. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/chunked_loss/test_simpo_loss.py +0 -0
  180. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/conftest.py +0 -0
  181. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/convergence/__init__.py +0 -0
  182. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/convergence/test_mini_models.py +0 -0
  183. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/convergence/test_mini_models_multimodal.py +0 -0
  184. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/convergence/test_mini_models_with_logits.py +0 -0
  185. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  186. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  187. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  188. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/resources/tiny_shakespeare.txt +0 -0
  189. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  190. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  191. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  192. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_auto_model.py +0 -0
  193. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_cross_entropy.py +0 -0
  194. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_embedding.py +0 -0
  195. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  196. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_fused_linear_jsd.py +0 -0
  197. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_geglu.py +0 -0
  198. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_group_norm.py +0 -0
  199. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_jsd.py +0 -0
  200. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_kl_div.py +0 -0
  201. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_layer_norm.py +0 -0
  202. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_mm_int8int2.py +0 -0
  203. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_monkey_patch.py +0 -0
  204. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_qwen2vl_mrope.py +0 -0
  205. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_rms_norm.py +0 -0
  206. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_rope.py +0 -0
  207. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_swiglu.py +0 -0
  208. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_trainer_integration.py +0 -0
  209. {liger_kernel_nightly-0.5.2.dev20250122005057 → liger_kernel_nightly-0.5.2.dev20250124002122}/test/transformers/test_transformers.py +0 -0
@@ -5,6 +5,7 @@ site/
5
5
  .venv/
6
6
  venv/
7
7
  .ipynb_checkpoints/
8
+ .vscode/
8
9
 
9
10
  # Misc
10
11
  .DS_Store
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20250122005057
3
+ Version: 0.5.2.dev20250124002122
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -0,0 +1,30 @@
1
+ ## Benchmarking Liger Kernels
2
+
3
+ Follow these steps to benchmark and visualize kernel performance:
4
+
5
+ 1. Create a benchmark script
6
+ - Add your script under `benchmark/scripts/`
7
+ - Name it according to the kernel (e.g., `benchmark_<kernel_name>.py`)
8
+
9
+ 2. Run the benchmark
10
+ - Results will be saved to `benchmark/data/all_benchmark_data.csv`
11
+
12
+ Example: Benchmarking KTO Loss
13
+ ```bash
14
+ cd benchmark
15
+ python scripts/benchmark_kto_loss.py
16
+ ```
17
+
18
+ 3. Visualize results
19
+ - Run `benchmarks_visualizer.py`, passing the kernel name, the metric name, and the kernel operation mode to plot
20
+
21
+ Example: Visualizing KTO Loss benchmark results
22
+ ```bash
23
+ python benchmarks_visualizer.py \
24
+ --kernel-name kto_loss \
25
+ --metric-name memory \
26
+ --kernel-operation-mode full
27
+ ```
28
+
29
+ 4. View results
30
+ - Generated plots will be saved in `benchmark/visualizations/`
@@ -715,3 +715,33 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314
715
715
  fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
716
716
  fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
717
717
  fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
718
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,7.841599941253662,7.801983833312988,7.849664211273193,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
719
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,15.568096160888672,15.555737495422363,16.054176330566406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
720
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,31.145376205444336,30.750951766967773,31.5398006439209,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
721
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,61.49708938598633,61.49708938598633,61.49708938598633,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
722
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,122.01449584960938,122.01449584960938,122.01449584960938,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
723
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,7.892335891723633,7.8687615394592285,8.03729248046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
724
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,14.16302490234375,13.813311576843262,15.860223770141602,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
725
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,25.56470489501953,25.564167022705078,25.641658782958984,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
726
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,53.0928955078125,53.0928955078125,53.0928955078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
727
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,108.76080322265625,108.76080322265625,108.76080322265625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
728
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),2,8.662687301635742,8.488287925720215,9.611334800720215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
729
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),4,18.40096092224121,17.99224281311035,18.57883644104004,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
730
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),8,32.09159851074219,31.708070755004883,32.475128173828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
731
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),16,69.30239868164062,69.30239868164062,69.30239868164062,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
732
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),32,124.2437744140625,124.2437744140625,124.2437744140625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
733
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,11.449472427368164,11.407564163208008,11.773555755615234,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
734
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,20.871471405029297,20.862951278686523,20.879276275634766,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
735
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,41.16409683227539,40.760780334472656,41.567413330078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
736
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,77.720703125,77.720703125,77.720703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
737
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,156.25794982910156,156.25794982910156,156.25794982910156,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
738
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2027.48583984375,2027.48583984375,2027.48583984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
739
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),4,2789.736328125,2789.736328125,2789.736328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
740
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),8,2801.751953125,2801.751953125,2801.751953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
741
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),16,2825.783203125,2825.783203125,2825.783203125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
742
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),32,2873.845703125,2873.845703125,2873.845703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
743
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,3786.7373046875,3786.7373046875,3786.7373046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
744
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.25390625,5544.25390625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
745
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
746
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
747
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
@@ -0,0 +1,314 @@
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+ import triton
6
+
7
+ from utils import QUANTILES
8
+ from utils import SingleBenchmarkRunInput
9
+ from utils import SingleBenchmarkRunOutput
10
+ from utils import _test_memory
11
+ from utils import parse_benchmark_script_args
12
+ from utils import run_benchmarks
13
+
14
+ from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
15
+ from liger_kernel.utils import infer_device
16
+
17
device = infer_device()
# Prepend the repository root (two directories above this script) to the import
# path; presumably this is what lets the lazy `from test.chunked_loss...` import
# inside TorchLMHeadKTO resolve when the script is run from benchmark/.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)))
19
+
20
+
21
class TorchLMHeadKTO(torch.nn.Module):
    """Reference KTO head used as the ``huggingface`` benchmark baseline.

    Owns a policy projection ``lin`` and a reference projection ``ref_lin``
    (both ``H -> V`` linear layers) and forwards them, together with the
    inputs, to ``HFKTOLoss.get_batch_loss_metrics``.
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        use_bias: bool = False,
        use_ref_bias: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        # Imported inside __init__ (presumably so this module imports even when
        # the test package is not on sys.path).
        from test.chunked_loss.test_kto_loss import HFKTOLoss

        super().__init__()
        linear_kwargs = dict(in_features=H, out_features=V, dtype=dtype)
        self.lin = torch.nn.Linear(bias=use_bias, **linear_kwargs)
        self.ref_lin = torch.nn.Linear(bias=use_ref_bias, **linear_kwargs)
        reference_loss = HFKTOLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
        )
        self.KTO_loss = reference_loss.get_batch_loss_metrics

    def forward(self, x, ref_x, y, preference_labels, kl=None):
        """Return the reference KTO loss for policy input ``x`` and reference input ``ref_x``."""
        policy, reference = self.lin, self.ref_lin
        return self.KTO_loss(
            weight=policy.weight,
            _input=x,
            target=y,
            bias=policy.bias,
            ref_input=ref_x,
            ref_weight=reference.weight,
            ref_bias=reference.bias,
            preference_labels=preference_labels,
            kl=kl,
        )
55
+
56
+
57
class LigerLMHeadKTO(torch.nn.Module):
    """Liger fused-linear KTO head used as the ``liger`` benchmark provider.

    Owns a policy projection ``lin`` and a reference projection ``ref_lin``
    (both ``H -> V`` linear layers) and hands their weights/biases to
    ``LigerFusedLinearKTOLoss`` together with the inputs.
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        use_bias: bool = False,
        use_ref_bias: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        super().__init__()
        linear_kwargs = dict(in_features=H, out_features=V, dtype=dtype)
        self.lin = torch.nn.Linear(bias=use_bias, **linear_kwargs)
        self.ref_lin = torch.nn.Linear(bias=use_ref_bias, **linear_kwargs)
        self.KTO_loss = LigerFusedLinearKTOLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
        )

    def forward(self, x, ref_x, y, preference_labels, kl=None):
        """Return the fused KTO loss for policy input ``x`` and reference input ``ref_x``."""
        policy, reference = self.lin, self.ref_lin
        return self.KTO_loss(
            _input=x,
            lin_weight=policy.weight,
            target=y,
            preference_labels=preference_labels,
            bias=policy.bias,
            ref_input=ref_x,
            ref_weight=reference.weight,
            ref_bias=reference.bias,
            kl=kl,
        )
89
+
90
+
91
def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory of a full (forward + backward) KTO loss step.

    Both the Liger head and the HuggingFace-style reference head are built.
    The one named by ``input.kernel_provider`` is run on random policy and
    reference inputs of shape ``[B, T, H]``. It also receives ``[B, T]``
    targets (partly overwritten with ``ignore_index``), ``[B]`` boolean
    preference labels and a one-element KL tensor.

    Returns:
        SingleBenchmarkRunOutput with the 20th/50th/80th memory quantiles
        reported by ``_test_memory``.

    Raises:
        ValueError: if ``input.kernel_provider`` is not ``"liger"`` or
            ``"huggingface"``.
    """
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    beta = input.extra_benchmark_config["beta"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider

    # Fail fast: previously an unknown provider made fwd() return None and the
    # run crashed later with an AttributeError on `None.backward()`.
    if provider not in ("liger", "huggingface"):
        raise ValueError(f"Unknown kernel provider: {provider!r}")

    # The heads take `use_bias` / `use_ref_bias`; passing `bias=` / `ref_bias=`
    # raises TypeError at construction.
    torch_kto_loss = TorchLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        use_bias=bias,
        use_ref_bias=bias,
        ignore_index=ignore_index,
        beta=beta,
    ).to(device)

    liger_kto_loss = LigerLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        use_bias=bias,
        use_ref_bias=bias,
        ignore_index=ignore_index,
        beta=beta,
    ).to(device)

    kto_head = liger_kto_loss if provider == "liger" else torch_kto_loss

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)

    # Target shape: [B, T]
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    # Preference labels shape: [B]; True marks preferred sequences, False
    # non-preferred ones.
    preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device)

    # Precomputed KL divergence between policy and reference distributions
    kl = torch.randn(1, device=device, dtype=dtype)

    # Overwrite a random subset of targets with ignore_index to simulate padding
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index

    # Reference-model input, same shape as _input
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

    def fwd():
        return kto_head(
            x=_input,
            ref_x=ref_input,
            y=target,
            preference_labels=preference_labels,
            kl=kl,
        )

    def full():
        y = fwd()
        y.backward()

    # Unpacked as (median, 20th, 80th): assumes QUANTILES is ordered
    # [0.5, 0.2, 0.8] -- confirm in utils.
    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
172
+
173
+
174
def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time the KTO loss in ``forward``, ``backward`` or ``full`` mode.

    Both the Liger head and the HuggingFace-style reference head are built.
    The one named by ``input.kernel_provider`` is timed with
    ``triton.testing.do_bench`` on random policy and reference inputs of
    shape ``[B, T, H]``. It also receives ``[B, T]`` targets (partly
    overwritten with ``ignore_index``), ``[B]`` boolean preference labels
    and a one-element KL tensor.

    Returns:
        SingleBenchmarkRunOutput with the 20th/50th/80th timing quantiles (ms).

    Raises:
        ValueError: if the kernel provider or operation mode is unknown.
    """
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    beta = input.extra_benchmark_config["beta"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    # Fail fast instead of crashing later. Previously an unknown provider gave
    # AttributeError on None and an unknown mode gave UnboundLocalError.
    if provider not in ("liger", "huggingface"):
        raise ValueError(f"Unknown kernel provider: {provider!r}")
    if mode not in ("forward", "backward", "full"):
        raise ValueError(f"Unknown kernel operation mode: {mode!r}")

    # The heads take `use_bias` / `use_ref_bias` (passing `bias=` raises
    # TypeError). The reference bias follows `bias`, consistent with
    # bench_memory_kto_loss.
    torch_kto_loss = TorchLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        beta=beta,
        ignore_index=ignore_index,
        use_bias=bias,
        use_ref_bias=bias,
    ).to(device)
    liger_kto_loss = LigerLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        beta=beta,
        ignore_index=ignore_index,
        use_bias=bias,
        use_ref_bias=bias,
    ).to(device)

    kto_head = liger_kto_loss if provider == "liger" else torch_kto_loss

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)

    # Target shape: [B, T]
    target = torch.randint(V, (B, T), device=device, dtype=torch.long)

    # Preference labels shape: [B]; True marks preferred sequences, False
    # non-preferred ones.
    preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device)

    # Precomputed KL divergence between policy and reference distributions
    kl = torch.randn(1, device=device, dtype=dtype)

    # Overwrite a random subset of targets with ignore_index
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index

    # Reference-model input, same shape as _input
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

    def fwd():
        return kto_head(
            x=_input,
            ref_x=ref_input,
            y=target,
            preference_labels=preference_labels,
            kl=kl,
        )

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()
        # retain_graph so the same graph can be back-propagated repeatedly
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    else:  # mode == "full"

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
274
+
275
+
276
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Shape/hyper-parameter settings recorded with every result row.
    kto_config = {
        "T": 512,
        "H": 1024,
        "V": 128256,
        "mode": "forward",
        "dtype": torch.bfloat16,
        "bias": True,
        "beta": 0.1,
        "ignore_index": 42,
    }
    common_configs = dict(
        kernel_name="kto_loss",
        x_name="B",
        x_label="Batch Size (B)",
        x_values=[1 << exponent for exponent in range(1, 6)],  # 2, 4, ..., 32
        kernel_providers=["liger", "huggingface"],
        extra_benchmark_configs=[kto_config],
        overwrite=args.overwrite,
    )

    # Speed is measured for the forward pass and the full step; memory only
    # for the full step. Runs in this order: speed first, then memory.
    benchmark_plan = (
        (bench_speed_kto_loss, ["forward", "full"], "speed", "ms"),
        (bench_memory_kto_loss, ["full"], "memory", "MB"),
    )
    for bench_fn, modes, metric, unit in benchmark_plan:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=modes,
            metric_name=metric,
            metric_unit=unit,
            **common_configs,
        )
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.2.dev20250122005057"
7
+ version = "0.5.2.dev20250124002122"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -1,6 +1,6 @@
1
1
  # Liger FlexChunkLoss: Alignment and Distillation loss
2
2
 
3
- Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
3
+ Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and, very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
4
4
 
5
5
  ### User interface
6
6
 
@@ -1,4 +1,5 @@
1
1
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
2
2
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
3
+ from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
3
4
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
4
5
  from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
@@ -1,5 +1,6 @@
1
1
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2
2
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3
+ from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
3
4
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
4
5
  from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
5
6
 
@@ -7,3 +8,4 @@ liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
7
8
  liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
8
9
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
9
10
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
11
+ liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply