liger-kernel-nightly 0.5.2.dev20250130024630__tar.gz → 0.5.2.dev20250130172806__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (212) hide show
  1. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/data/all_benchmark_data.csv +24 -0
  3. liger_kernel_nightly-0.5.2.dev20250130172806/benchmark/scripts/benchmark_distill_jsd_loss.py +261 -0
  4. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/pyproject.toml +1 -1
  5. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/__init__.py +1 -0
  6. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/functional.py +2 -0
  7. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +14 -5
  8. liger_kernel_nightly-0.5.2.dev20250130172806/src/liger_kernel/chunked_loss/jsd_loss.py +154 -0
  9. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  10. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel_nightly.egg-info/SOURCES.txt +3 -0
  11. liger_kernel_nightly-0.5.2.dev20250130172806/test/chunked_loss/test_jsd_loss.py +318 -0
  12. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/utils.py +4 -1
  13. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  14. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  15. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/.github/pull_request_template.md +0 -0
  16. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/.github/workflows/amd-ci.yml +0 -0
  17. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/.github/workflows/docs.yml +0 -0
  18. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/.github/workflows/nvi-ci.yml +0 -0
  19. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/.github/workflows/publish-nightly.yml +0 -0
  20. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/.github/workflows/publish-release.yml +0 -0
  21. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/.gitignore +0 -0
  22. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/LICENSE +0 -0
  23. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/Makefile +0 -0
  24. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/NOTICE +0 -0
  25. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/README.md +0 -0
  26. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/README.md +0 -0
  27. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/__init__.py +0 -0
  28. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/benchmarks_visualizer.py +0 -0
  29. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/__init__.py +0 -0
  30. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  31. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  33. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_embedding.py +0 -0
  34. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  35. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  36. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_geglu.py +0 -0
  37. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_group_norm.py +0 -0
  38. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_jsd.py +0 -0
  39. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_kl_div.py +0 -0
  40. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  41. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  42. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  43. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  44. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  45. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_rope.py +0 -0
  46. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  47. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/benchmark_swiglu.py +0 -0
  48. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/benchmark/scripts/utils.py +0 -0
  49. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/dev/fmt-requirements.txt +0 -0
  50. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/dev/modal/tests.py +0 -0
  51. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/dev/modal/tests_bwd.py +0 -0
  52. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/Examples.md +0 -0
  53. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/Getting-Started.md +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/High-Level-APIs.md +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/Low-Level-APIs.md +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/acknowledgement.md +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/contributing.md +0 -0
  58. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/images/banner.GIF +0 -0
  59. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/images/compose.gif +0 -0
  60. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/images/e2e-memory.png +0 -0
  61. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/images/e2e-tps.png +0 -0
  62. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/images/logo-banner.png +0 -0
  63. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/images/patch.gif +0 -0
  64. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/images/post-training.png +0 -0
  65. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/index.md +0 -0
  66. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/docs/license.md +0 -0
  67. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/alignment/accelerate_config.yaml +0 -0
  68. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/alignment/run_orpo.py +0 -0
  69. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/README.md +0 -0
  70. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/callback.py +0 -0
  71. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/config/fsdp_config.json +0 -0
  72. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  73. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  74. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  75. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/img/llama_tps.png +0 -0
  76. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  77. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/img/qwen_tps.png +0 -0
  78. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/launch_on_modal.py +0 -0
  79. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/requirements.txt +0 -0
  80. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/run_benchmarks.sh +0 -0
  81. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/run_gemma.sh +0 -0
  82. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/run_llama.sh +0 -0
  83. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/run_qwen.sh +0 -0
  84. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/run_qwen2_vl.sh +0 -0
  85. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/training.py +0 -0
  86. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/huggingface/training_multimodal.py +0 -0
  87. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/lightning/README.md +0 -0
  88. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/lightning/requirements.txt +0 -0
  89. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/lightning/training.py +0 -0
  90. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/README.md +0 -0
  91. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/callback.py +0 -0
  92. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  93. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  94. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  95. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  96. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  97. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  98. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  99. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  100. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  101. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/medusa_util.py +0 -0
  102. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/requirements.txt +0 -0
  103. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  104. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/examples/medusa/train.py +0 -0
  105. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/licenses/LICENSE-Apache-2.0 +0 -0
  106. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  107. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  108. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/licenses/LICENSE-MIT-llmc +0 -0
  109. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/licenses/LICENSE-MIT-triton +0 -0
  110. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/mkdocs.yml +0 -0
  111. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/setup.cfg +0 -0
  112. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/setup.py +0 -0
  113. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/__init__.py +0 -0
  114. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/README.md +0 -0
  115. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  116. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  117. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  118. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  119. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  120. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  121. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  122. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/env_report.py +0 -0
  123. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/__init__.py +0 -0
  124. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/cross_entropy.py +0 -0
  125. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  126. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  127. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  128. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  129. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/geglu.py +0 -0
  130. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/group_norm.py +0 -0
  131. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/jsd.py +0 -0
  132. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/kl_div.py +0 -0
  133. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/layer_norm.py +0 -0
  134. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  135. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/rms_norm.py +0 -0
  136. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/rope.py +0 -0
  137. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/swiglu.py +0 -0
  138. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/ops/utils.py +0 -0
  139. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/__init__.py +0 -0
  140. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/auto_model.py +0 -0
  141. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  142. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  143. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/functional.py +0 -0
  144. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  145. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  146. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/geglu.py +0 -0
  147. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/group_norm.py +0 -0
  148. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/jsd.py +0 -0
  149. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/kl_div.py +0 -0
  150. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/layer_norm.py +0 -0
  151. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/model/__init__.py +0 -0
  152. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/model/gemma.py +0 -0
  153. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  154. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/model/llama.py +0 -0
  155. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/model/mistral.py +0 -0
  156. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  157. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/model/mllama.py +0 -0
  158. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/model/phi3.py +0 -0
  159. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  160. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  161. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  162. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  163. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/rms_norm.py +0 -0
  164. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/rope.py +0 -0
  165. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/swiglu.py +0 -0
  166. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  167. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  168. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  169. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/triton/__init__.py +0 -0
  170. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/triton/monkey_patch.py +0 -0
  171. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel/utils.py +0 -0
  172. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  173. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  174. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  175. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/__init__.py +0 -0
  176. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/chunked_loss/__init__.py +0 -0
  177. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/chunked_loss/test_cpo_loss.py +0 -0
  178. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/chunked_loss/test_dpo_loss.py +0 -0
  179. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/chunked_loss/test_kto_loss.py +0 -0
  180. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/chunked_loss/test_orpo_loss.py +0 -0
  181. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/chunked_loss/test_simpo_loss.py +0 -0
  182. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/conftest.py +0 -0
  183. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/convergence/__init__.py +0 -0
  184. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/convergence/test_mini_models.py +0 -0
  185. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/convergence/test_mini_models_multimodal.py +0 -0
  186. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/convergence/test_mini_models_with_logits.py +0 -0
  187. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  188. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  189. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  190. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/resources/tiny_shakespeare.txt +0 -0
  191. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  192. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  193. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  194. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_auto_model.py +0 -0
  195. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_cross_entropy.py +0 -0
  196. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_embedding.py +0 -0
  197. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  198. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_fused_linear_jsd.py +0 -0
  199. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_geglu.py +0 -0
  200. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_group_norm.py +0 -0
  201. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_jsd.py +0 -0
  202. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_kl_div.py +0 -0
  203. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_layer_norm.py +0 -0
  204. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_mm_int8int2.py +0 -0
  205. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_monkey_patch.py +0 -0
  206. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_qwen2vl_mrope.py +0 -0
  207. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_rms_norm.py +0 -0
  208. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_rope.py +0 -0
  209. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_swiglu.py +0 -0
  210. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_trainer_integration.py +0 -0
  211. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/transformers/test_transformers.py +0 -0
  212. {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130172806}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20250130024630
3
+ Version: 0.5.2.dev20250130172806
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -745,3 +745,27 @@ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.253906
745
745
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
746
746
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
747
747
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
748
+ distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
749
+ distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
750
+ distill_jsd_loss,liger,forward,speed,ms,BT,B x T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
751
+ distill_jsd_loss,liger,forward,speed,ms,BT,B x T,8192,60.24163055419922,60.24163055419922,60.24163055419922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
752
+ distill_jsd_loss,torch,forward,speed,ms,BT,B x T,1024,10.906111717224121,10.903244972229004,10.91296672821045,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
753
+ distill_jsd_loss,torch,forward,speed,ms,BT,B x T,2048,21.480207443237305,21.465139389038086,21.489286422729492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
754
+ distill_jsd_loss,torch,forward,speed,ms,BT,B x T,4096,42.96339416503906,42.96237564086914,42.96440887451172,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
755
+ distill_jsd_loss,torch,forward,speed,ms,BT,B x T,8192,85.3946533203125,85.3946533203125,85.3946533203125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
756
+ distill_jsd_loss,liger,full,speed,ms,BT,B x T,1024,8.312895774841309,8.310400009155273,8.326751708984375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
757
+ distill_jsd_loss,liger,full,speed,ms,BT,B x T,2048,15.770208358764648,15.767775535583496,15.774784088134766,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
758
+ distill_jsd_loss,liger,full,speed,ms,BT,B x T,4096,30.922752380371094,30.920312881469727,30.927898406982422,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
759
+ distill_jsd_loss,liger,full,speed,ms,BT,B x T,8192,60.70627212524414,60.70627212524414,60.70627212524414,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
760
+ distill_jsd_loss,torch,full,speed,ms,BT,B x T,1024,28.72480010986328,28.718809127807617,28.728179931640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
761
+ distill_jsd_loss,torch,full,speed,ms,BT,B x T,2048,54.281761169433594,54.281761169433594,54.281761169433594,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
762
+ distill_jsd_loss,torch,full,speed,ms,BT,B x T,4096,107.08905792236328,107.08905792236328,107.08905792236328,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
763
+ distill_jsd_loss,torch,full,speed,ms,BT,B x T,8192,213.1598663330078,213.1598663330078,213.1598663330078,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
764
+ distill_jsd_loss,liger,full,memory,MB,BT,B x T,1024,10913.541015625,10913.541015625,10913.541015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
765
+ distill_jsd_loss,liger,full,memory,MB,BT,B x T,2048,10941.548828125,10941.548828125,10941.548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
766
+ distill_jsd_loss,liger,full,memory,MB,BT,B x T,4096,10997.564453125,10997.564453125,10997.564453125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
767
+ distill_jsd_loss,liger,full,memory,MB,BT,B x T,8192,11109.595703125,11109.595703125,11109.595703125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
768
+ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,16174.0390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
769
+ distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
770
+ distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
771
+ distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
@@ -0,0 +1,261 @@
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+ import triton
6
+
7
+ from utils import QUANTILES
8
+ from utils import SingleBenchmarkRunInput
9
+ from utils import SingleBenchmarkRunOutput
10
+ from utils import _test_memory
11
+ from utils import parse_benchmark_script_args
12
+ from utils import run_benchmarks
13
+
14
+ from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
15
+ from liger_kernel.utils import infer_device
16
+
17
+ device = infer_device()
18
+
19
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
20
+
21
+
22
+ class TorchJSDLoss(torch.nn.Module):
23
+ def __init__(
24
+ self,
25
+ H: int,
26
+ V: int,
27
+ dtype: torch.dtype,
28
+ weight_hard_loss: float = 0.5,
29
+ weight_soft_loss: float = 0.5,
30
+ ignore_index: int = -100,
31
+ temperature: float = 1.0,
32
+ bias: bool = False,
33
+ ):
34
+ from test.chunked_loss.test_jsd_loss import HFJSDLoss
35
+
36
+ super().__init__()
37
+ self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype)
38
+ self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
39
+ self.jsd_loss = HFJSDLoss(
40
+ ignore_index=ignore_index,
41
+ weight_hard_loss=weight_hard_loss,
42
+ weight_soft_loss=weight_soft_loss,
43
+ temperature=temperature,
44
+ ).get_batch_loss_metrics
45
+
46
+ def forward(self, student, teacher, target):
47
+ return self.jsd_loss(
48
+ student,
49
+ self.student_lin.weight,
50
+ teacher,
51
+ self.teacher_lin.weight,
52
+ target,
53
+ )
54
+
55
+
56
+ class LigerJSDLoss(torch.nn.Module):
57
+ def __init__(
58
+ self,
59
+ H: int,
60
+ V: int,
61
+ dtype: torch.dtype,
62
+ weight_hard_loss: float = 0.5,
63
+ weight_soft_loss: float = 0.5,
64
+ ignore_index: int = -100,
65
+ temperature: float = 1.0,
66
+ bias: bool = False,
67
+ ):
68
+ super().__init__()
69
+ self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype)
70
+ self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
71
+ self.weight_hard_loss = weight_hard_loss
72
+ self.weight_soft_loss = weight_soft_loss
73
+ self.ignore_index = ignore_index
74
+ self.temperature = temperature
75
+ self.jsd_loss = LigerFusedLinearJSDFunction.apply
76
+
77
+ def forward(self, student, teacher, target):
78
+ return self.jsd_loss(
79
+ student,
80
+ self.student_lin.weight,
81
+ teacher,
82
+ self.teacher_lin.weight,
83
+ target,
84
+ self.weight_hard_loss,
85
+ self.weight_soft_loss,
86
+ )
87
+
88
+
89
+ def bench_memory_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
90
+ BT = input.x
91
+ H = input.extra_benchmark_config["H"]
92
+ V = input.extra_benchmark_config["V"]
93
+ dtype = input.extra_benchmark_config["dtype"]
94
+ bias = input.extra_benchmark_config["bias"]
95
+ weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
96
+ weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
97
+ ignore_index = input.extra_benchmark_config["ignore_index"]
98
+ provider = input.kernel_provider
99
+
100
+ torch_jsd_loss = TorchJSDLoss(
101
+ H=H,
102
+ V=V,
103
+ dtype=dtype,
104
+ ignore_index=ignore_index,
105
+ bias=bias,
106
+ weight_hard_loss=weight_hard_loss,
107
+ weight_soft_loss=weight_soft_loss,
108
+ ).to(device)
109
+ liger_jsd_loss = LigerJSDLoss(
110
+ H=H,
111
+ V=V,
112
+ dtype=dtype,
113
+ ignore_index=ignore_index,
114
+ bias=bias,
115
+ weight_hard_loss=weight_hard_loss,
116
+ weight_soft_loss=weight_soft_loss,
117
+ ).to(device)
118
+
119
+ _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
120
+ student_input1 = _tensor.detach().clone().requires_grad_(True)
121
+ student_input2 = _tensor.detach().clone().requires_grad_(True)
122
+
123
+ teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
124
+
125
+ target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
126
+
127
+ def fwd():
128
+ if provider == "liger":
129
+ return liger_jsd_loss(student_input1, teacher_input, target)
130
+ elif provider == "torch":
131
+ return torch_jsd_loss(student_input2, teacher_input, target)
132
+
133
+ def full():
134
+ y = fwd()
135
+ y.backward()
136
+
137
+ mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
138
+ return SingleBenchmarkRunOutput(
139
+ y_20=mem_20,
140
+ y_50=mem_50,
141
+ y_80=mem_80,
142
+ )
143
+
144
+
145
+ def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
146
+ BT = input.x
147
+ H = input.extra_benchmark_config["H"]
148
+ V = input.extra_benchmark_config["V"]
149
+ dtype = input.extra_benchmark_config["dtype"]
150
+ bias = input.extra_benchmark_config["bias"]
151
+ weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
152
+ weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
153
+ ignore_index = input.extra_benchmark_config["ignore_index"]
154
+ provider = input.kernel_provider
155
+ mode = input.kernel_operation_mode
156
+
157
+ torch_jsd_loss = TorchJSDLoss(
158
+ H=H,
159
+ V=V,
160
+ dtype=dtype,
161
+ ignore_index=ignore_index,
162
+ bias=bias,
163
+ weight_hard_loss=weight_hard_loss,
164
+ weight_soft_loss=weight_soft_loss,
165
+ ).to(device)
166
+ liger_jsd_loss = LigerJSDLoss(
167
+ H=H,
168
+ V=V,
169
+ dtype=dtype,
170
+ ignore_index=ignore_index,
171
+ bias=bias,
172
+ weight_hard_loss=weight_hard_loss,
173
+ weight_soft_loss=weight_soft_loss,
174
+ ).to(device)
175
+
176
+ _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
177
+ student_input1 = _tensor.detach().clone().requires_grad_(True)
178
+ student_input2 = _tensor.detach().clone().requires_grad_(True)
179
+
180
+ teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
181
+
182
+ target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
183
+
184
+ def fwd():
185
+ if provider == "liger":
186
+ return liger_jsd_loss(student_input1, teacher_input, target)
187
+ elif provider == "torch":
188
+ return torch_jsd_loss(student_input2, teacher_input, target)
189
+
190
+ if mode == "forward":
191
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(
192
+ fwd,
193
+ rep=100,
194
+ quantiles=QUANTILES,
195
+ )
196
+ elif mode == "backward":
197
+ y = fwd()
198
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(
199
+ lambda: y.backward(retain_graph=True),
200
+ grad_to_none=[student_input1, student_input2],
201
+ rep=100,
202
+ quantiles=QUANTILES,
203
+ )
204
+ elif mode == "full":
205
+
206
+ def full():
207
+ y = fwd()
208
+ y.backward()
209
+
210
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(
211
+ full,
212
+ rep=100,
213
+ quantiles=QUANTILES,
214
+ )
215
+
216
+ return SingleBenchmarkRunOutput(
217
+ y_20=ms_20,
218
+ y_50=ms_50,
219
+ y_80=ms_80,
220
+ )
221
+
222
+
223
+ if __name__ == "__main__":
224
+ args = parse_benchmark_script_args()
225
+
226
+ common_configs = {
227
+ "kernel_name": "distill_jsd_loss",
228
+ "x_name": "BT",
229
+ "x_label": "B x T",
230
+ "x_values": [2**i for i in range(10, 14)],
231
+ "kernel_providers": ["liger", "torch"],
232
+ "extra_benchmark_configs": [
233
+ {
234
+ "H": 4096,
235
+ "V": 128256,
236
+ "mode": "forward",
237
+ "dtype": torch.bfloat16,
238
+ "bias": False,
239
+ "weight_hard_loss": 0.5,
240
+ "weight_soft_loss": 0.5,
241
+ "ignore_index": -100,
242
+ }
243
+ ],
244
+ "overwrite": args.overwrite,
245
+ }
246
+
247
+ run_benchmarks(
248
+ bench_test_fn=bench_speed_jsd_loss,
249
+ kernel_operation_modes=["forward", "full"],
250
+ metric_name="speed",
251
+ metric_unit="ms",
252
+ **common_configs,
253
+ )
254
+
255
+ run_benchmarks(
256
+ bench_test_fn=bench_memory_jsd_loss,
257
+ kernel_operation_modes=["full"],
258
+ metric_name="memory",
259
+ metric_unit="MB",
260
+ **common_configs,
261
+ )
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.2.dev20250130024630"
7
+ version = "0.5.2.dev20250130172806"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -1,5 +1,6 @@
1
1
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
2
2
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
3
+ from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
3
4
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
4
5
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
5
6
  from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
@@ -1,11 +1,13 @@
1
1
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2
2
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3
+ from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
3
4
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
4
5
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
5
6
  from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
6
7
 
7
8
  liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
8
9
  liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
10
+ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
9
11
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
10
12
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
11
13
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
@@ -17,6 +17,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
17
17
  Args:
18
18
  student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
19
19
  teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
20
+ Returns:
21
+ torch.Tensor: Sum of distillation losses for the chunk. The class will handle
22
+ converting this to mean loss by dividing by the full batch size * sequence length in _compute_loss.
20
23
  """
21
24
  raise NotImplementedError("Distillation loss function must be implemented.")
22
25
 
@@ -71,10 +74,11 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
71
74
  weight_hard_loss=0.5,
72
75
  weight_soft_loss=0.5,
73
76
  compute_ce_loss=True,
77
+ temperature=1,
74
78
  **loss_kwargs,
75
79
  ):
76
80
  """
77
- Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function.
81
+ Compute the total loss for a chunk of input and target, while using a knowledge distillation loss function.
78
82
  Args:
79
83
  distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
80
84
  student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
@@ -84,11 +88,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
84
88
  target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
85
89
  student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
86
90
  teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
87
- full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
91
+ full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
88
92
  ignore_index (int): Index to ignore for loss computation.
89
93
  weight_hard_loss (float): Weight for hard loss.
90
94
  weight_soft_loss (float): Weight for soft loss.
91
95
  compute_ce_loss (bool): Whether to compute CE loss.
96
+ temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
92
97
  loss_kwargs (dict): Additional arguments for the loss function.
93
98
  """
94
99
  (
@@ -107,6 +112,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
107
112
  compute_ce_loss=compute_ce_loss,
108
113
  )
109
114
 
115
+ student_logits_chunk /= temperature
116
+ teacher_logits_chunk /= temperature
117
+
110
118
  hard_loss /= full_target.shape[0]
111
119
 
112
120
  soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
@@ -130,6 +138,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
130
138
  ignore_index=-100,
131
139
  weight_hard_loss=0.5,
132
140
  weight_soft_loss=0.5,
141
+ beta=0.5,
133
142
  compute_ce_loss=True,
134
143
  temperature=1.0,
135
144
  compiled=True,
@@ -152,6 +161,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
152
161
  ignore_index (int): Index to ignore for loss computation.
153
162
  weight_hard_loss (float): Weight for hard/task loss.
154
163
  weight_soft_loss (float): Weight for soft/distillation loss.
164
+ beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
155
165
  compute_ce_loss (bool): Whether to compute CE loss.
156
166
  temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
157
167
  compiled (bool): Whether to use torch compile for chunk accumulation.
@@ -170,7 +180,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
170
180
  ignore_index=ignore_index,
171
181
  weight_hard_loss=weight_hard_loss,
172
182
  weight_soft_loss=weight_soft_loss,
183
+ beta=beta,
173
184
  compute_ce_loss=compute_ce_loss,
185
+ temperature=temperature,
174
186
  **loss_kwargs,
175
187
  )
176
188
 
@@ -225,9 +237,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
225
237
  if compiled:
226
238
  accumulate_chunk = torch.compile(accumulate_chunk)
227
239
 
228
- student_input /= temperature
229
- teacher_input /= temperature
230
-
231
240
  num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
232
241
  _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
233
242
  _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
@@ -0,0 +1,154 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
5
+
6
+
7
+ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
8
+ @staticmethod
9
+ def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
10
+ """
11
+ Compute JSD loss (Jensen-Shannon Divergence Loss).
12
+ Args:
13
+ student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
14
+ teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
15
+ beta (float): Coefficient beta of generalized JSD in the interval [0, 1]; it weights the student probabilities in the mixture and the teacher KL term in the final sum. Default: `0.5`.
16
+ Returns:
17
+ torch.Tensor: Jensen-Shannon Divergence loss
18
+ """
19
+ student_log_probs = F.log_softmax(student_logits, dim=-1)
20
+ teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
21
+
22
+ # Compute probabilities (only required for mean calculation)
23
+ mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
24
+ log_mean_probs = mean_probs.log()
25
+
26
+ student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
27
+ teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
28
+
29
+ # JSD is the weighted average of the KL divergences
30
+ jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
31
+ return jsd_loss
32
+
33
+ @staticmethod
34
+ def forward(
35
+ ctx,
36
+ student_input: torch.Tensor,
37
+ student_weight: torch.Tensor,
38
+ teacher_input: torch.Tensor,
39
+ teacher_weight: torch.Tensor,
40
+ true_labels: torch.LongTensor,
41
+ weight_hard_loss: float = 0.5,
42
+ weight_soft_loss: float = 0.5,
43
+ beta: float = 0.5,
44
+ ignore_index: int = -100,
45
+ temperature: float = 1.0,
46
+ compiled: bool = True,
47
+ ):
48
+ """
49
+ Fused linear layer with JSD distillation loss.
50
+ Args:
51
+ student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
52
+ student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
53
+ teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
54
+ teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
55
+ true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
56
+ weight_hard_loss (float): Weight for hard loss.
57
+ weight_soft_loss (float): Weight for soft loss.
58
+ beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
59
+ ignore_index (int): Index to ignore in loss computation
60
+ temperature (float): Temperature for softening/sharpening distributions
61
+ compiled (bool): Whether to use torch compile
62
+ Returns:
63
+ torch.Tensor: Computed loss
64
+ """
65
+ return LigerFusedLinearDistillationBase.forward(
66
+ ctx=ctx,
67
+ student_input=student_input,
68
+ student_weight=student_weight,
69
+ teacher_input=teacher_input,
70
+ teacher_weight=teacher_weight,
71
+ target=true_labels,
72
+ loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
73
+ chunk_size=1,
74
+ weight_hard_loss=weight_hard_loss,
75
+ weight_soft_loss=weight_soft_loss,
76
+ beta=beta,
77
+ ignore_index=ignore_index,
78
+ temperature=temperature,
79
+ compiled=compiled,
80
+ )
81
+
82
+ @staticmethod
83
+ def backward(ctx, grad_output):
84
+ grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
85
+
86
+ return (*grads, None, None, None, None, None, None, None)
87
+
88
+
89
+ class LigerFusedLinearJSDLoss(torch.nn.Module):
90
+ """
91
+ Fused linear layer with JSD distillation loss.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ weight_hard_loss: float = 0.5,
97
+ weight_soft_loss: float = 0.5,
98
+ beta: float = 0.5,
99
+ ignore_index: int = -100,
100
+ temperature: float = 1.0,
101
+ compiled: bool = True,
102
+ ):
103
+ """
104
+ Args:
105
+ weight_hard_loss (float): Weight for hard loss.
106
+ weight_soft_loss (float): Weight for soft loss.
107
+ ignore_index (int): Index to ignore in the loss
108
+ temperature (float): Temperature for softening distributions
109
+ compiled (bool): Whether to use torch compile
110
+ beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
111
+ """
112
+ super().__init__()
113
+ assert temperature != 0, "Temperature cannot be 0."
114
+ self.weight_hard_loss = weight_hard_loss
115
+ self.weight_soft_loss = weight_soft_loss
116
+ self.ignore_index = ignore_index
117
+ self.temperature = temperature
118
+ self.compiled = compiled
119
+ self.beta = beta
120
+
121
+ def forward(
122
+ self,
123
+ student_input: torch.Tensor,
124
+ student_weight: torch.Tensor,
125
+ teacher_input: torch.Tensor,
126
+ teacher_weight: torch.Tensor,
127
+ true_labels: torch.LongTensor,
128
+ ) -> torch.Tensor:
129
+ """
130
+ Compute the JSD distillation loss.
131
+
132
+ Args:
133
+ student_input (torch.Tensor): Student input tensor
134
+ student_weight (torch.Tensor): Student weight tensor
135
+ teacher_input (torch.Tensor): Teacher input tensor
136
+ teacher_weight (torch.Tensor): Teacher weight tensor
137
+ true_labels (torch.LongTensor): Target labels tensor
138
+
139
+ Returns:
140
+ torch.Tensor: Computed loss
141
+ """
142
+ return LigerFusedLinearJSDFunction.apply(
143
+ student_input,
144
+ student_weight,
145
+ teacher_input,
146
+ teacher_weight,
147
+ true_labels,
148
+ self.weight_hard_loss,
149
+ self.weight_soft_loss,
150
+ self.beta,
151
+ self.ignore_index,
152
+ self.temperature,
153
+ self.compiled,
154
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20250130024630
3
+ Version: 0.5.2.dev20250130172806
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -21,6 +21,7 @@ benchmark/data/all_benchmark_data.csv
21
21
  benchmark/scripts/__init__.py
22
22
  benchmark/scripts/benchmark_cpo_loss.py
23
23
  benchmark/scripts/benchmark_cross_entropy.py
24
+ benchmark/scripts/benchmark_distill_jsd_loss.py
24
25
  benchmark/scripts/benchmark_dpo_loss.py
25
26
  benchmark/scripts/benchmark_embedding.py
26
27
  benchmark/scripts/benchmark_fused_linear_cross_entropy.py
@@ -110,6 +111,7 @@ src/liger_kernel/chunked_loss/functional.py
110
111
  src/liger_kernel/chunked_loss/fused_linear_distillation.py
111
112
  src/liger_kernel/chunked_loss/fused_linear_preference.py
112
113
  src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py
114
+ src/liger_kernel/chunked_loss/jsd_loss.py
113
115
  src/liger_kernel/chunked_loss/kto_loss.py
114
116
  src/liger_kernel/chunked_loss/orpo_loss.py
115
117
  src/liger_kernel/chunked_loss/simpo_loss.py
@@ -172,6 +174,7 @@ test/utils.py
172
174
  test/chunked_loss/__init__.py
173
175
  test/chunked_loss/test_cpo_loss.py
174
176
  test/chunked_loss/test_dpo_loss.py
177
+ test/chunked_loss/test_jsd_loss.py
175
178
  test/chunked_loss/test_kto_loss.py
176
179
  test/chunked_loss/test_orpo_loss.py
177
180
  test/chunked_loss/test_simpo_loss.py