liger-kernel-nightly 0.5.3.dev20250220230230__tar.gz → 0.5.3.dev20250221002845__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/PKG-INFO +2 -1
  2. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/README.md +1 -0
  3. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/data/all_benchmark_data.csv +37 -0
  4. liger_kernel_nightly-0.5.3.dev20250221002845/benchmark/scripts/benchmark_tvd.py +136 -0
  5. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/pyproject.toml +1 -1
  6. liger_kernel_nightly-0.5.3.dev20250221002845/src/liger_kernel/ops/tvd.py +208 -0
  7. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/__init__.py +1 -0
  8. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/functional.py +15 -1
  9. liger_kernel_nightly-0.5.3.dev20250221002845/src/liger_kernel/transformers/tvd.py +15 -0
  10. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel_nightly.egg-info/PKG-INFO +2 -1
  11. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
  12. liger_kernel_nightly-0.5.3.dev20250221002845/test/transformers/test_tvd.py +195 -0
  13. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  14. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  15. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/.github/pull_request_template.md +0 -0
  16. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/.github/workflows/amd-ci.yml +0 -0
  17. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/.github/workflows/docs.yml +0 -0
  18. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/.github/workflows/intel-ci.yml +0 -0
  19. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/.github/workflows/nvi-ci.yml +0 -0
  20. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/.github/workflows/publish-nightly.yml +0 -0
  21. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/.github/workflows/publish-release.yml +0 -0
  22. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/.gitignore +0 -0
  23. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/LICENSE +0 -0
  24. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/Makefile +0 -0
  25. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/NOTICE +0 -0
  26. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/README.md +0 -0
  27. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/__init__.py +0 -0
  28. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/benchmarks_visualizer.py +0 -0
  29. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/__init__.py +0 -0
  30. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  31. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  33. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  34. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_embedding.py +0 -0
  35. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  36. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  37. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_geglu.py +0 -0
  38. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_group_norm.py +0 -0
  39. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_jsd.py +0 -0
  40. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_kl_div.py +0 -0
  41. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  42. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  43. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  44. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  45. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  46. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_rope.py +0 -0
  47. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  48. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/benchmark_swiglu.py +0 -0
  49. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/benchmark/scripts/utils.py +0 -0
  50. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/dev/fmt-requirements.txt +0 -0
  51. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/dev/modal/tests.py +0 -0
  52. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/dev/modal/tests_bwd.py +0 -0
  53. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/Examples.md +0 -0
  54. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/Getting-Started.md +0 -0
  55. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/High-Level-APIs.md +0 -0
  56. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/Low-Level-APIs.md +0 -0
  57. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/acknowledgement.md +0 -0
  58. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/contributing.md +0 -0
  59. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/images/banner.GIF +0 -0
  60. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/images/compose.gif +0 -0
  61. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/images/e2e-memory.png +0 -0
  62. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/images/e2e-tps.png +0 -0
  63. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/images/logo-banner.png +0 -0
  64. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/images/patch.gif +0 -0
  65. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/images/post-training.png +0 -0
  66. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/index.md +0 -0
  67. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/docs/license.md +0 -0
  68. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/alignment/accelerate_config.yaml +0 -0
  69. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/alignment/run_orpo.py +0 -0
  70. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/README.md +0 -0
  71. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/callback.py +0 -0
  72. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/config/fsdp_config.json +0 -0
  73. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  74. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  75. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  76. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/img/llama_tps.png +0 -0
  77. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  78. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/img/qwen_tps.png +0 -0
  79. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/launch_on_modal.py +0 -0
  80. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/requirements.txt +0 -0
  81. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/run_benchmarks.sh +0 -0
  82. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/run_gemma.sh +0 -0
  83. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/run_llama.sh +0 -0
  84. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/run_qwen.sh +0 -0
  85. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/run_qwen2_vl.sh +0 -0
  86. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/training.py +0 -0
  87. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/huggingface/training_multimodal.py +0 -0
  88. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/lightning/README.md +0 -0
  89. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/lightning/requirements.txt +0 -0
  90. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/lightning/training.py +0 -0
  91. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/README.md +0 -0
  92. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/callback.py +0 -0
  93. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  94. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  95. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  96. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  97. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  98. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  99. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  100. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  101. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  102. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/medusa_util.py +0 -0
  103. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/requirements.txt +0 -0
  104. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  105. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/examples/medusa/train.py +0 -0
  106. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/licenses/LICENSE-Apache-2.0 +0 -0
  107. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  108. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  109. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/licenses/LICENSE-MIT-llmc +0 -0
  110. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/licenses/LICENSE-MIT-triton +0 -0
  111. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/mkdocs.yml +0 -0
  112. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/setup.cfg +0 -0
  113. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/setup.py +0 -0
  114. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/__init__.py +0 -0
  115. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/README.md +0 -0
  116. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  117. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  118. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  119. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/functional.py +0 -0
  120. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  121. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  122. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
  123. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  124. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  125. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  126. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  127. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/env_report.py +0 -0
  130. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/__init__.py +0 -0
  131. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/cross_entropy.py +0 -0
  132. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  133. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  134. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  135. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  136. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/geglu.py +0 -0
  137. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/group_norm.py +0 -0
  138. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/jsd.py +0 -0
  139. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/kl_div.py +0 -0
  140. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/layer_norm.py +0 -0
  141. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  142. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/rms_norm.py +0 -0
  143. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/rope.py +0 -0
  144. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/swiglu.py +0 -0
  145. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/ops/utils.py +0 -0
  146. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/auto_model.py +0 -0
  147. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  148. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  149. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  150. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  151. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/geglu.py +0 -0
  152. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/group_norm.py +0 -0
  153. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/jsd.py +0 -0
  154. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/kl_div.py +0 -0
  155. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/layer_norm.py +0 -0
  156. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/model/__init__.py +0 -0
  157. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/model/gemma.py +0 -0
  158. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  159. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/model/llama.py +0 -0
  160. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/model/mistral.py +0 -0
  161. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  162. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/model/mllama.py +0 -0
  163. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/model/phi3.py +0 -0
  164. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  165. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  166. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  167. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  168. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/rms_norm.py +0 -0
  169. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/rope.py +0 -0
  170. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/swiglu.py +0 -0
  171. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  172. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  173. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  174. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/triton/__init__.py +0 -0
  175. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/triton/monkey_patch.py +0 -0
  176. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel/utils.py +0 -0
  177. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  178. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  179. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  180. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/__init__.py +0 -0
  181. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/chunked_loss/__init__.py +0 -0
  182. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/chunked_loss/test_cpo_loss.py +0 -0
  183. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/chunked_loss/test_dpo_loss.py +0 -0
  184. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/chunked_loss/test_grpo_loss.py +0 -0
  185. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/chunked_loss/test_jsd_loss.py +0 -0
  186. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/chunked_loss/test_kto_loss.py +0 -0
  187. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/chunked_loss/test_orpo_loss.py +0 -0
  188. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/chunked_loss/test_simpo_loss.py +0 -0
  189. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/conftest.py +0 -0
  190. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/convergence/__init__.py +0 -0
  191. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/convergence/bf16/__init__.py +0 -0
  192. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/convergence/bf16/test_mini_models.py +0 -0
  193. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  194. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  195. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/convergence/fp32/__init__.py +0 -0
  196. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/convergence/fp32/test_mini_models.py +0 -0
  197. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  198. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  199. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  200. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  201. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  202. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/resources/tiny_shakespeare.txt +0 -0
  203. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  204. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  205. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  206. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_auto_model.py +0 -0
  207. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_cross_entropy.py +0 -0
  208. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_embedding.py +0 -0
  209. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  210. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_fused_linear_jsd.py +0 -0
  211. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_geglu.py +0 -0
  212. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_group_norm.py +0 -0
  213. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_jsd.py +0 -0
  214. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_kl_div.py +0 -0
  215. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_layer_norm.py +0 -0
  216. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_mm_int8int2.py +0 -0
  217. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_monkey_patch.py +0 -0
  218. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_qwen2vl_mrope.py +0 -0
  219. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_rms_norm.py +0 -0
  220. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_rope.py +0 -0
  221. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_swiglu.py +0 -0
  222. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_trainer_integration.py +0 -0
  223. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/transformers/test_transformers.py +0 -0
  224. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/triton/test_triton_monkey_patch.py +0 -0
  225. {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221002845}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250220230230
3
+ Version: 0.5.3.dev20250221002845
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -349,6 +349,7 @@ loss.backward()
349
349
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
350
350
  | JSD | `liger_kernel.transformers.LigerJSD` |
351
351
  | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
352
+ | TVD | `liger_kernel.transformers.LigerTVDLoss` |
352
353
 
353
354
  ### Experimental Kernels
354
355
 
@@ -301,6 +301,7 @@ loss.backward()
301
301
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
302
302
  | JSD | `liger_kernel.transformers.LigerJSD` |
303
303
  | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
304
+ | TVD | `liger_kernel.transformers.LigerTVDLoss` |
304
305
 
305
306
  ### Experimental Kernels
306
307
 
@@ -505,6 +505,42 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859
505
505
  fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
506
506
  fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
507
507
  fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
508
+ tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
509
+ tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
510
+ tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
511
+ tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
512
+ tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
513
+ tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
514
+ tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
515
+ tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
516
+ tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
517
+ tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
518
+ tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
519
+ tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
520
+ tvd,liger,forward,speed,ms,V,vocab size,4096,0.47814399003982544,0.4774720072746277,0.4790079891681671,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
521
+ tvd,liger,forward,speed,ms,V,vocab size,8192,0.906495988368988,0.905951976776123,0.9073920249938965,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
522
+ tvd,liger,forward,speed,ms,V,vocab size,16384,1.8787360191345215,1.8778239488601685,1.8797119855880737,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
523
+ tvd,liger,forward,speed,ms,V,vocab size,32768,3.5788800716400146,3.5772159099578857,3.58076810836792,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
524
+ tvd,liger,forward,speed,ms,V,vocab size,65536,7.008831977844238,7.007718086242676,7.010636806488037,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
525
+ tvd,liger,forward,speed,ms,V,vocab size,131072,13.88646411895752,13.88128662109375,13.890560150146484,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
526
+ tvd,torch,forward,speed,ms,V,vocab size,4096,1.308608055114746,1.306502342224121,1.3104127645492554,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
527
+ tvd,torch,forward,speed,ms,V,vocab size,8192,2.4735519886016846,2.472287893295288,2.4749441146850586,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
528
+ tvd,torch,forward,speed,ms,V,vocab size,16384,4.828320026397705,4.826848030090332,4.830643177032471,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
529
+ tvd,torch,forward,speed,ms,V,vocab size,32768,9.5206880569458,9.517024040222168,9.525145530700684,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
530
+ tvd,torch,forward,speed,ms,V,vocab size,65536,19.01535987854004,19.011123657226562,19.01806640625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
531
+ tvd,torch,forward,speed,ms,V,vocab size,131072,38.022865295410156,38.01945877075195,38.02627182006836,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
532
+ tvd,liger,full,speed,ms,V,vocab size,4096,2.626512050628662,2.621260643005371,2.646751880645752,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
533
+ tvd,liger,full,speed,ms,V,vocab size,8192,4.661711692810059,4.657618999481201,4.662930965423584,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
534
+ tvd,liger,full,speed,ms,V,vocab size,16384,9.088272094726562,9.080741882324219,9.092268943786621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
535
+ tvd,liger,full,speed,ms,V,vocab size,32768,18.116064071655273,18.112728118896484,18.118234634399414,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
536
+ tvd,liger,full,speed,ms,V,vocab size,65536,35.85124969482422,35.849971771240234,35.85252380371094,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
537
+ tvd,liger,full,speed,ms,V,vocab size,131072,71.1648941040039,71.1648941040039,71.1648941040039,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
538
+ tvd,torch,full,speed,ms,V,vocab size,4096,4.361599922180176,4.360159873962402,4.3639678955078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
539
+ tvd,torch,full,speed,ms,V,vocab size,8192,8.11302375793457,8.11075210571289,8.114463806152344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
540
+ tvd,torch,full,speed,ms,V,vocab size,16384,15.841055870056152,15.837087631225586,15.841856002807617,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
541
+ tvd,torch,full,speed,ms,V,vocab size,32768,31.71219253540039,31.706951141357422,31.715898513793945,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
542
+ tvd,torch,full,speed,ms,V,vocab size,65536,63.17919921875,63.17919921875,63.17919921875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
543
+ tvd,torch,full,speed,ms,V,vocab size,131072,126.0436782836914,126.0436782836914,126.0436782836914,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
508
544
  group_norm,liger,forward,speed,ms,C,num_channels,32,0.03481600061058998,0.03379200026392937,0.03993599861860275,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
509
545
  group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05119999870657921,0.05222399905323982,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
510
546
  group_norm,liger,forward,speed,ms,C,num_channels,128,0.08499199897050858,0.08396799862384796,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
@@ -769,3 +805,4 @@ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,
769
805
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
770
806
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
771
807
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
808
+
@@ -0,0 +1,136 @@
1
+ import torch
2
+ import triton
3
+ from utils import (
4
+ QUANTILES,
5
+ SingleBenchmarkRunInput,
6
+ SingleBenchmarkRunOutput,
7
+ _test_memory,
8
+ parse_benchmark_script_args,
9
+ run_benchmarks,
10
+ )
11
+
12
+ from liger_kernel.transformers.tvd import LigerTVDLoss
13
+
14
+
15
class TorchTVDLoss(torch.nn.Module):
    """Pure-PyTorch Total Variation Distance loss used as the benchmark baseline.

    The elementwise loss is ``|p - q| / 2``. It is then reduced according to
    ``reduction``:

    - ``"none"``: the elementwise tensor.
    - ``"sum"``: its total.
    - ``"mean"``: the total divided by ``p.size(0) * p.size(1)``.
    - ``"batchmean"`` (default): the total divided by ``p.size(0)``.

    Any other value raises ``ValueError`` when ``forward`` is called.
    """

    def __init__(self, reduction="batchmean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, p, q):
        half_abs_diff = (p - q).abs() / 2.0
        mode = self.reduction
        if mode == "none":
            return half_abs_diff
        if mode == "sum":
            return half_abs_diff.sum()
        if mode == "mean":
            # average over every element of the 2-D input
            return half_abs_diff.sum() / (p.size(0) * p.size(1))
        if mode == "batchmean":
            # average over rows only
            return half_abs_diff.sum() / p.size(0)
        raise ValueError("Invalid reduction type.")
32
+
33
+
34
# Start / exclusive-end exponents for the power-of-two vocab-size sweep
# (V in 2**12 .. 2**17).
S = 12
E = 18
35
+
36
+
37
def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time the TVD loss for one benchmark configuration.

    Setup:
        ``input.x`` is the vocab size V. ``input.extra_benchmark_config``
        supplies ``B`` and ``T``. The loss is evaluated on ``B * T`` rows of
        softmax-normalised random distributions on CUDA.

    Provider:
        ``input.kernel_provider == "liger"`` selects ``LigerTVDLoss``. Any
        other value selects the ``TorchTVDLoss`` reference.

    Modes (``input.kernel_operation_mode``):
        "forward"  -- loss computation only.
        "backward" -- repeated backward through one precomputed loss.
        "full"     -- forward followed by backward.

    Returns:
        The y_20 / y_50 / y_80 timings (ms) reported by
        ``triton.testing.do_bench`` for ``QUANTILES``.

    Raises:
        ValueError: if ``kernel_operation_mode`` is not one of the modes above.
    """
    reduction = "batchmean"
    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
    torch_tvd = TorchTVDLoss(reduction=reduction)
    liger_tvd = LigerTVDLoss(reduction=reduction)

    # Keep a handle on the leaf tensor. ``_input`` is the (non-leaf) softmax
    # output, whose ``.grad`` is never populated, so resetting gradients
    # between timed runs must target the leaf.
    logits = torch.randn(B * T, V, requires_grad=True, device="cuda")
    _input = logits.softmax(dim=-1)
    target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

    def fwd():
        if input.kernel_provider == "liger":
            return liger_tvd(_input, target)
        return torch_tvd(_input, target)

    mode = input.kernel_operation_mode
    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
    elif mode == "backward":
        y = fwd()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            # retain_graph: the same graph is backpropagated on every run
            lambda: y.backward(retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[logits],
            rep=100,
        )
    elif mode == "full":

        def full():
            y = fwd()
            # retain_graph: the softmax that produced ``_input`` is shared by every run
            y.backward(retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
    else:
        # Without this, an unknown mode would crash later with UnboundLocalError.
        raise ValueError(f"Unsupported kernel_operation_mode: {mode!r}")

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
78
+
79
+
80
def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure the memory of one forward + backward TVD step.

    Setup:
        ``input.x`` is the vocab size V. ``input.extra_benchmark_config``
        supplies ``B`` and ``T``. Inputs are ``B * T`` rows of
        softmax-normalised random distributions on CUDA.

    Provider:
        ``input.kernel_provider == "liger"`` profiles ``LigerTVDLoss``. Any
        other value profiles the ``TorchTVDLoss`` reference.

    Returns:
        The y_20 / y_50 / y_80 values reported by ``_test_memory`` for
        ``QUANTILES``.
    """
    losses = {
        "torch": TorchTVDLoss(reduction="batchmean"),
        "liger": LigerTVDLoss(reduction="batchmean"),
    }

    vocab_size = input.x
    num_rows = input.extra_benchmark_config["B"] * input.extra_benchmark_config["T"]

    p = torch.randn(num_rows, vocab_size, requires_grad=True, device="cuda").softmax(dim=-1)
    q = torch.randn(num_rows, vocab_size, device="cuda").softmax(dim=-1)

    def step():
        loss_fn = losses["liger"] if input.kernel_provider == "liger" else losses["torch"]
        loss = loss_fn(p, q)
        # retain_graph: the softmax producing ``p`` is reused by every profiled step
        loss.backward(retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(step, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
108
+
109
+
110
if __name__ == "__main__":
    args = parse_benchmark_script_args()
    common_args = {
        "kernel_name": "tvd",
        "x_name": "V",
        "x_label": "vocab size",
        # Power-of-two vocab sizes 2**S .. 2**(E - 1). Use the module-level
        # bounds instead of repeating them as literals.
        "x_values": [2**i for i in range(S, E)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [{"B": 8, "T": 2048}],
        "overwrite": args.overwrite,
    }

    # Memory of a full forward + backward step.
    run_benchmarks(
        bench_test_fn=bench_memory_tvd,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_args,
    )

    # Latency of the forward pass alone, and of forward + backward.
    run_benchmarks(
        bench_test_fn=bench_speed_tvd,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_args,
    )
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.3.dev20250220230230"
7
+ version = "0.5.3.dev20250221002845"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,208 @@
1
+ from typing import Literal, Optional
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import ensure_contiguous
8
+
9
# Upper bound on BLOCK_SIZE. Rows wider than this are walked by the kernel in
# several BLOCK_SIZE-wide chunks.
MAX_FUSED_SIZE = 65536 // 4

# Reduction names accepted by the public API.
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

# Compile-time reduction codes passed as the Triton kernel's ``reduction`` argument.
_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

# Reduction name -> plain-int code (the ``.value`` of the constexprs above).
_str_to_reduction_mode = {
    name: mode.value
    for name, mode in (
        ("none", _REDUCTION_MODE_NONE),
        ("sum", _REDUCTION_MODE_SUM),
        ("mean", _REDUCTION_MODE_MEAN),
        ("batchmean", _REDUCTION_MODE_BATCHMEAN),
    )
}
24
+
25
+
26
def get_num_warps(BLOCK_SIZE):
    """Return the Triton ``num_warps`` launch value for a block width.

    Wider blocks get more warps:

    - ``>= 32768`` -> 32
    - ``>= 8192`` -> 16
    - ``>= 2048`` -> 8
    - anything smaller -> 4
    """
    # Thresholds are checked from largest to smallest; the first match wins.
    for threshold, warps in ((32768, 32), (8192, 16), (2048, 8)):
        if BLOCK_SIZE >= threshold:
            return warps
    return 4
36
+
37
+
38
@triton.jit
def _tv_distance_kernel(
    p_ptr,
    p_stride,
    q_ptr,
    q_stride,
    loss_ptr,
    loss_stride,
    grads_ptr,
    grads_stride,
    label_ptr,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
    """Per-row Total Variation Distance that also writes the gradient w.r.t. p.

    Each program instance handles one row (selected by ``tl.program_id(0)``).
    It walks the row's ``n_cols`` columns in ``BLOCK_SIZE``-wide chunks; the
    tail chunk is masked.

    Args:
        p_ptr, q_ptr: base pointers of the two inputs. Consecutive rows are
            ``p_stride`` / ``q_stride`` elements apart.
        loss_ptr: loss output, row offset ``pid * loss_stride``.
            - ``"none"`` mode: one value per element.
            - Other modes: a single per-row sum at the row's first slot. The
              kernel applies no normalisation.
        grads_ptr: output for the elementwise d(0.5 * |p - q|)/dp, row offset
            ``pid * grads_stride``.
        label_ptr: one label per row. Only read when ``HAS_LABEL`` is set.
        ignore_index: rows whose label equals this get a zero gradient. In
            ``"none"`` mode they also get a zero loss.
        n_cols: number of columns in each row.
        BLOCK_SIZE: chunk width (compile-time constant).
        HAS_LABEL: compile-time switch for the label / ignore logic.
        reduction: compile-time reduction code. This kernel only distinguishes
            ``_REDUCTION_MODE_NONE`` from "sum the row into one value".
    """
    # Promote to int64 before multiplying by the strides so row offsets are
    # computed in 64-bit arithmetic.
    pid = tl.program_id(0).to(tl.int64)
    p_ptr += pid * p_stride
    q_ptr += pid * q_stride
    loss_ptr += pid * loss_stride
    grads_ptr += pid * grads_stride
    # One label per row.
    label_ptr += pid

    base_offsets = tl.arange(0, BLOCK_SIZE)

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # Ignored row: zero its gradient, and its per-element loss in
            # "none" mode, then exit. In the reduced modes nothing is written
            # to loss_ptr for this row, so the caller must pre-zero that buffer.
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + base_offsets
                mask = offsets < n_cols
                tl.store(grads_ptr + offsets, 0.0, mask=mask)
                if reduction == _REDUCTION_MODE_NONE:
                    tl.store(loss_ptr + offsets, 0.0, mask=mask)
            return

    loss_sum = 0.0
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols

        # Masked tail lanes read 0 for both p and q, so they contribute 0 loss.
        p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
        q = tl.load(q_ptr + offsets, mask=mask, other=0.0)

        # TVD(P || Q) = 0.5 * |P - Q| (elementwise; summed over the row in the
        # reduced modes).
        tv_loss = 0.5 * tl.abs(p - q)

        # d/dp of 0.5 * |p - q| is +0.5 or -0.5. At p == q this picks -0.5.
        grad_res = tl.where(p > q, 0.5, -0.5)

        tl.store(grads_ptr + offsets, grad_res, mask=mask)

        if reduction == _REDUCTION_MODE_NONE:
            tl.store(loss_ptr + offsets, tv_loss, mask=mask)
        else:
            loss_sum += tl.sum(tv_loss, axis=0)

    if reduction != _REDUCTION_MODE_NONE:
        # One scalar per row; reduction across rows happens in the caller.
        tl.store(loss_ptr, loss_sum)
97
+
98
+
99
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
    """Launch the TVD kernel and apply the requested reduction.

    Args:
        p: 2-D tensor of shape (BT, V). Gradients are computed w.r.t. this input.
        q: tensor with the same shape as ``p``.
        shift_labels: per-row labels of shape (BT,). Only read when
            ``has_label`` is True.
        reduction: one of ``"none"``, ``"sum"``, ``"mean"``, ``"batchmean"``.
        ignore_index: label value marking rows that contribute neither loss
            nor gradient.
        has_label: whether ``shift_labels`` should be consulted.

    Returns:
        ``(loss, grads)``:
        - ``loss`` is float32, shape (BT, V) for ``"none"`` and a scalar otherwise.
        - ``grads`` has ``p``'s shape and dtype and is already divided by the
          same normaliser as ``loss``.

    Normalisers:
        - ``"batchmean"``: the number of non-ignored rows.
        - ``"mean"``: the number of non-ignored rows times V.
        - ``"sum"`` and ``"none"``: unnormalised.
        If every row is ignored, the normalised modes return a zero loss and
        zero gradients instead of NaN.
    """
    BT, V = p.shape

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    num_warps = get_num_warps(BLOCK_SIZE)

    grid = (BT,)

    reduction = _str_to_reduction_mode[reduction]

    # "none" keeps one loss value per element. Every other mode only needs a
    # per-row partial sum, which is reduced with torch below. The buffer is
    # zero-initialised because the kernel skips ignored rows in those modes.
    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
    output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
    grads = torch.empty_like(p)

    n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT

    _tv_distance_kernel[grid](
        p,
        p.stride(0),
        q,
        q.stride(0),
        output_tensor,
        output_tensor.stride(0),
        grads,
        grads.stride(0),
        # Dummy pointer when there are no labels; the kernel never reads it then.
        shift_labels if has_label else torch.empty(1, device=p.device),
        ignore_index,
        V,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
        num_warps=num_warps,
        reduction=reduction,
    )

    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
        # If every row is ignored, loss and grads are all zero. Clamp the
        # divisor so the result is 0 rather than 0 / 0 = NaN, which would
        # otherwise propagate into the optimizer.
        n_rows = max(n_non_ignore, 1)
        return output_tensor.sum() / n_rows, grads / n_rows
    if reduction == _REDUCTION_MODE_SUM.value:
        return output_tensor.sum(dim=0), grads
    if reduction == _REDUCTION_MODE_MEAN.value:
        # Same all-ignored guard as "batchmean", over elements instead of rows.
        n_elements = max(n_non_ignore * V, 1)
        return output_tensor.sum() / n_elements, grads / n_elements
    return output_tensor, grads
141
+
142
+
143
def tvd_backward_triton(grad_output, grads):
    """Chain the upstream gradient into the gradient precomputed in the forward pass.

    Args:
        grad_output: gradient of the final objective w.r.t. the TVD loss output.
            It is a scalar for the reduced modes and has the loss's
            (BT, V) shape for ``"none"``.
        grads: d(loss)/dp saved by the forward pass.

    Returns:
        ``grads`` itself when ``grad_output`` is exactly the scalar 1.0,
        otherwise ``grads * grad_output``.
    """
    # When the TVD loss is the final scalar being differentiated, autograd seeds
    # grad_output with 1.0. Skip the multiply (and its allocation) in that case.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return grads

    return grads * grad_output
149
+
150
+
151
class LigerTVDLossFunction(torch.autograd.Function):
    """Total Variation Distance loss as an autograd function backed by a Triton kernel.

    The forward pass computes the loss and its gradient w.r.t. ``p`` in one
    kernel launch and saves that gradient. The backward pass only rescales it
    by the incoming ``grad_output``.

    Only ``p`` receives a gradient. ``q`` and the non-tensor arguments always
    get ``None``.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        p: torch.Tensor,
        q: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        reduction: REDUCTION_LITERAL = "batchmean",
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """A forward pass for the Total Variation Distance Loss.

        Args:
            ctx: Torch autograd context
            p (torch.Tensor): A tensor of shape (BT, V) containing the first
                distribution. Gradients are computed with respect to this input.
            q (torch.Tensor): A tensor of shape (BT, V) containing the second
                distribution. It receives no gradient.
            shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,)
                containing one label per row. Rows labelled ``ignore_index``
                contribute no loss or gradient.
            reduction (REDUCTION_LITERAL, optional): The reduction method to be
                applied. Defaults to "batchmean".
            ignore_index (int, optional): The index to ignore during loss
                calculation. Defaults to -100.

        Returns:
            torch.Tensor: The computed Total Variation Distance Loss (float32).
            Shape (BT, V) for ``"none"``, a scalar otherwise.
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (
                p.shape[0],
            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, grads = tv_distance_forward_triton(
            p, q, shift_labels, reduction, ignore_index, has_label
        )
        # The gradient w.r.t. p is computed eagerly, so backward is just a rescale.
        ctx.save_for_backward(grads)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        """A backward pass for the Total Variation Distance Loss.

        Args:
            ctx: Torch autograd context
            grad_output (torch.Tensor): The gradient of the loss with respect to
                the output.

        Returns:
            tuple: ``(grad_p, None, None, None, None)``, one entry per forward
            input. ``q``, ``shift_labels``, ``reduction`` and ``ignore_index``
            receive no gradient.
        """
        (grads,) = ctx.saved_tensors
        grads = tvd_backward_triton(grad_output, grads)

        return grads, None, None, None, None
@@ -18,6 +18,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2
18
18
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
19
19
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
20
20
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
21
+ from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
21
22
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
22
23
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
23
24
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
@@ -12,7 +12,7 @@ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
12
12
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
13
13
  from liger_kernel.ops.rope import LigerRopeFunction
14
14
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
15
-
15
+ from liger_kernel.ops.tvd import LigerTVDLossFunction
16
16
 
17
17
  # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
18
18
  # `weight` and `size_average` are placeholders and not implemented yet
@@ -156,6 +156,20 @@ def liger_kl_div(
156
156
  eps,
157
157
  )
158
158
 
159
def liger_tvd(
    input,
    target,
    shift_labels=None,
    reduction: str = "mean",
    ignore_index: int = -100,
):
    """Functional Total Variation Distance loss between ``input`` and ``target``.

    Thin wrapper over ``LigerTVDLossFunction``. Gradients flow to ``input`` only.

    Note: the default ``reduction`` here is ``"mean"``, whereas the
    ``LigerTVDLoss`` module defaults to ``"batchmean"``.
    """
    # autograd.Function.apply only accepts positional arguments.
    call_args = (input, target, shift_labels, reduction, ignore_index)
    return LigerTVDLossFunction.apply(*call_args)
159
173
 
160
174
  def liger_layer_norm(X, W, B, eps):
161
175
  return LigerLayerNormFunction.apply(X, W, B, eps)
@@ -0,0 +1,15 @@
1
+ import torch.nn as nn
2
+
3
+ from liger_kernel.ops.tvd import LigerTVDLossFunction
4
+
5
+
6
class LigerTVDLoss(nn.Module):
    """``nn.Module`` wrapper around ``LigerTVDLossFunction``.

    Args:
        reduction: how the elementwise loss is reduced. One of ``"none"``,
            ``"sum"``, ``"mean"`` or ``"batchmean"`` (the default).
        ignore_index: label value whose rows are excluded when ``shift_labels``
            is passed to ``forward``.
    """

    def __init__(self, reduction="batchmean", ignore_index: int = -100):
        super().__init__()
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, p, q, shift_labels=None):
        """Return the TVD loss between ``p`` and ``q``. Only ``p`` receives a gradient."""
        reduction, ignore_index = self.reduction, self.ignore_index
        return LigerTVDLossFunction.apply(p, q, shift_labels, reduction, ignore_index)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250220230230
3
+ Version: 0.5.3.dev20250221002845
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -349,6 +349,7 @@ loss.backward()
349
349
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
350
350
  | JSD | `liger_kernel.transformers.LigerJSD` |
351
351
  | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
352
+ | TVD | `liger_kernel.transformers.LigerTVDLoss` |
352
353
 
353
354
  ### Experimental Kernels
354
355
 
@@ -39,6 +39,7 @@ benchmark/scripts/benchmark_rms_norm.py
39
39
  benchmark/scripts/benchmark_rope.py
40
40
  benchmark/scripts/benchmark_simpo_loss.py
41
41
  benchmark/scripts/benchmark_swiglu.py
42
+ benchmark/scripts/benchmark_tvd.py
42
43
  benchmark/scripts/utils.py
43
44
  dev/fmt-requirements.txt
44
45
  dev/modal/tests.py
@@ -131,6 +132,7 @@ src/liger_kernel/ops/qwen2vl_mrope.py
131
132
  src/liger_kernel/ops/rms_norm.py
132
133
  src/liger_kernel/ops/rope.py
133
134
  src/liger_kernel/ops/swiglu.py
135
+ src/liger_kernel/ops/tvd.py
134
136
  src/liger_kernel/ops/utils.py
135
137
  src/liger_kernel/ops/experimental/embedding.py
136
138
  src/liger_kernel/ops/experimental/mm_int8int2.py
@@ -151,6 +153,7 @@ src/liger_kernel/transformers/rms_norm.py
151
153
  src/liger_kernel/transformers/rope.py
152
154
  src/liger_kernel/transformers/swiglu.py
153
155
  src/liger_kernel/transformers/trainer_integration.py
156
+ src/liger_kernel/transformers/tvd.py
154
157
  src/liger_kernel/transformers/experimental/embedding.py
155
158
  src/liger_kernel/transformers/model/__init__.py
156
159
  src/liger_kernel/transformers/model/gemma.py
@@ -216,4 +219,5 @@ test/transformers/test_rope.py
216
219
  test/transformers/test_swiglu.py
217
220
  test/transformers/test_trainer_integration.py
218
221
  test/transformers/test_transformers.py
222
+ test/transformers/test_tvd.py
219
223
  test/triton/test_triton_monkey_patch.py