liger-kernel-nightly 0.5.2.dev20250108102127__tar.gz → 0.5.2.dev20250109023714__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (197) hide show
  1. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_orpo_loss.py +6 -4
  3. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/fused_linear_preference.py +40 -12
  5. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/orpo_loss.py +5 -2
  6. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/trainer/orpo_trainer.py +16 -4
  7. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  8. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/chunked_loss/test_orpo_loss.py +10 -8
  9. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/utils.py +5 -3
  10. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  11. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  12. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/pull_request_template.md +0 -0
  13. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/workflows/amd-ci.yml +0 -0
  14. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/workflows/nvi-ci.yml +0 -0
  15. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/workflows/publish-nightly.yml +0 -0
  16. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/workflows/publish-release.yml +0 -0
  17. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.gitignore +0 -0
  18. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/LICENSE +0 -0
  19. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/Makefile +0 -0
  20. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/NOTICE +0 -0
  21. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/README.md +0 -0
  22. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/__init__.py +0 -0
  23. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/benchmarks_visualizer.py +0 -0
  24. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/data/all_benchmark_data.csv +0 -0
  25. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/__init__.py +0 -0
  26. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  27. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  28. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  29. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_embedding.py +0 -0
  30. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  31. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  32. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_geglu.py +0 -0
  33. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_group_norm.py +0 -0
  34. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_jsd.py +0 -0
  35. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_kl_div.py +0 -0
  36. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  37. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  38. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  39. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_rope.py +0 -0
  40. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  41. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_swiglu.py +0 -0
  42. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/utils.py +0 -0
  43. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/dev/fmt-requirements.txt +0 -0
  44. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/dev/modal/tests.py +0 -0
  45. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/dev/modal/tests_bwd.py +0 -0
  46. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/Acknowledgement.md +0 -0
  47. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/CONTRIBUTING.md +0 -0
  48. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/License.md +0 -0
  49. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/banner.GIF +0 -0
  50. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/compose.gif +0 -0
  51. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/e2e-memory.png +0 -0
  52. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/e2e-tps.png +0 -0
  53. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/logo-banner.png +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/patch.gif +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/post-training.png +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/alignment/accelerate_config.yaml +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/alignment/run_orpo.py +0 -0
  58. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/README.md +0 -0
  59. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/callback.py +0 -0
  60. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/config/fsdp_config.json +0 -0
  61. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  62. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  63. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  64. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/llama_tps.png +0 -0
  65. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  66. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/qwen_tps.png +0 -0
  67. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/launch_on_modal.py +0 -0
  68. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/requirements.txt +0 -0
  69. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/run_benchmarks.sh +0 -0
  70. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/run_gemma.sh +0 -0
  71. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/run_llama.sh +0 -0
  72. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/run_qwen.sh +0 -0
  73. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/run_qwen2_vl.sh +0 -0
  74. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/training.py +0 -0
  75. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/training_multimodal.py +0 -0
  76. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/lightning/README.md +0 -0
  77. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/lightning/requirements.txt +0 -0
  78. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/lightning/training.py +0 -0
  79. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/README.md +0 -0
  80. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/callback.py +0 -0
  81. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  82. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  83. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  84. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  85. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  86. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  87. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  88. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  89. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  90. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/medusa_util.py +0 -0
  91. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/requirements.txt +0 -0
  92. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  93. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/train.py +0 -0
  94. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/licenses/LICENSE-Apache-2.0 +0 -0
  95. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  96. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  97. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/licenses/LICENSE-MIT-llmc +0 -0
  98. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/licenses/LICENSE-MIT-triton +0 -0
  99. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/setup.cfg +0 -0
  100. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/setup.py +0 -0
  101. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/__init__.py +0 -0
  102. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/README.md +0 -0
  103. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  104. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  105. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  106. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/functional.py +0 -0
  107. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  108. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  109. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/env_report.py +0 -0
  110. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/__init__.py +0 -0
  111. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/cross_entropy.py +0 -0
  112. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  113. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  114. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  115. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  116. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/geglu.py +0 -0
  117. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/group_norm.py +0 -0
  118. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/jsd.py +0 -0
  119. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/kl_div.py +0 -0
  120. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/layer_norm.py +0 -0
  121. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  122. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/rms_norm.py +0 -0
  123. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/rope.py +0 -0
  124. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/swiglu.py +0 -0
  125. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/utils.py +0 -0
  126. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/__init__.py +0 -0
  127. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/auto_model.py +0 -0
  128. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  129. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  130. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/functional.py +0 -0
  131. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  132. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  133. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/geglu.py +0 -0
  134. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/group_norm.py +0 -0
  135. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/jsd.py +0 -0
  136. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/kl_div.py +0 -0
  137. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/layer_norm.py +0 -0
  138. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/__init__.py +0 -0
  139. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/gemma.py +0 -0
  140. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  141. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/llama.py +0 -0
  142. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/mistral.py +0 -0
  143. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  144. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/mllama.py +0 -0
  145. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/phi3.py +0 -0
  146. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  147. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  148. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  149. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  150. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/rms_norm.py +0 -0
  151. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/rope.py +0 -0
  152. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/swiglu.py +0 -0
  153. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  154. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  155. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/triton/__init__.py +0 -0
  156. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/triton/monkey_patch.py +0 -0
  157. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/utils.py +0 -0
  158. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  159. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  160. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  161. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  162. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/__init__.py +0 -0
  163. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/chunked_loss/__init__.py +0 -0
  164. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/chunked_loss/test_cpo_loss.py +0 -0
  165. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/chunked_loss/test_dpo_loss.py +0 -0
  166. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/chunked_loss/test_simpo_loss.py +0 -0
  167. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/conftest.py +0 -0
  168. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/convergence/__init__.py +0 -0
  169. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/convergence/test_mini_models.py +0 -0
  170. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/convergence/test_mini_models_multimodal.py +0 -0
  171. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/convergence/test_mini_models_with_logits.py +0 -0
  172. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  173. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  174. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  175. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/tiny_shakespeare.txt +0 -0
  176. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  177. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  178. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  179. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_auto_model.py +0 -0
  180. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_cross_entropy.py +0 -0
  181. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_embedding.py +0 -0
  182. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  183. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_fused_linear_jsd.py +0 -0
  184. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_geglu.py +0 -0
  185. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_group_norm.py +0 -0
  186. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_jsd.py +0 -0
  187. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_kl_div.py +0 -0
  188. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_layer_norm.py +0 -0
  189. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_mm_int8int2.py +0 -0
  190. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_monkey_patch.py +0 -0
  191. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_qwen2vl_mrope.py +0 -0
  192. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_rms_norm.py +0 -0
  193. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_rope.py +0 -0
  194. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_swiglu.py +0 -0
  195. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_trainer_integration.py +0 -0
  196. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_transformers.py +0 -0
  197. {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20250108102127
3
+ Version: 0.5.2.dev20250109023714
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -45,12 +45,13 @@ def bench_memory_fused_linear_orpo_loss(
45
45
 
46
46
  _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
47
47
  target = torch.randint(V, (B, T), dtype=torch.long, device=device)
48
+ nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
48
49
 
49
50
  def fwd():
50
51
  if provider == "liger":
51
- return liger_lm_head_orpo(_input, target)
52
+ return liger_lm_head_orpo(_input, target, nll_target)
52
53
  elif provider == "huggingface":
53
- return torch_lm_head_orpo(_input, target)
54
+ return torch_lm_head_orpo(_input, target, nll_target)
54
55
 
55
56
  def full():
56
57
  y = fwd()
@@ -91,12 +92,13 @@ def bench_speed_fused_linear_orpo_loss(
91
92
 
92
93
  _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
93
94
  target = torch.randint(V, (B, T), dtype=torch.long, device=device)
95
+ nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
94
96
 
95
97
  def fwd():
96
98
  if provider == "liger":
97
- return liger_lm_head_orpo(_input, target)
99
+ return liger_lm_head_orpo(_input, target, nll_target)
98
100
  elif provider == "huggingface":
99
- return torch_lm_head_orpo(_input, target)
101
+ return torch_lm_head_orpo(_input, target, nll_target)
100
102
 
101
103
  if mode == "forward":
102
104
  ms_50, ms_20, ms_80 = triton.testing.do_bench(
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.2.dev20250108102127"
7
+ version = "0.5.2.dev20250109023714"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -27,6 +27,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
27
27
  alpha=1.0,
28
28
  beta=0.1,
29
29
  compute_nll_loss=True,
30
+ nll_target=None,
30
31
  compiled=True,
31
32
  use_ref_model=False,
32
33
  ref_input=None,
@@ -58,6 +59,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
58
59
  alpha (float): Weight for the NLL loss.
59
60
  beta (float): Weight for the preference loss.
60
61
  compute_nll_loss (bool): Whether to compute NLL loss.
62
+ nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used.
61
63
  compiled (bool): Whether to use torch compile for chunk accumulation.
62
64
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
63
65
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
@@ -96,11 +98,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
96
98
  use_ref_model=use_ref_model,
97
99
  ref_weight=ref_weight,
98
100
  ref_bias=ref_bias,
101
+ full_nll_target=nll_target,
99
102
  average_log_prob=average_log_prob,
100
103
  **loss_kwargs,
101
104
  )
102
105
 
103
- def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
106
+ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
104
107
  """
105
108
  Fused forward and backward pass for a chunk of input and target.
106
109
  """
@@ -111,13 +114,18 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
111
114
  target_chunk,
112
115
  bias,
113
116
  ref_input_chunk=ref_input_chunk,
117
+ chosen_nll_target_chunk=chosen_nll_target_chunk,
114
118
  )
115
119
  else:
116
120
  return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
117
- input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk
121
+ input_chunk,
122
+ weight,
123
+ target_chunk,
124
+ ref_input_chunk=ref_input_chunk,
125
+ chosen_nll_target_chunk=chosen_nll_target_chunk,
118
126
  )
119
127
 
120
- def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
128
+ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
121
129
  if bias is not None:
122
130
  (
123
131
  (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
@@ -132,7 +140,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
132
140
  *aux_outputs,
133
141
  ),
134
142
  ),
135
- ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
143
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
136
144
  grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
137
145
  else:
138
146
  (
@@ -148,7 +156,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
148
156
  *aux_outputs,
149
157
  ),
150
158
  ),
151
- ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
159
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
152
160
 
153
161
  # Accumulate gradients
154
162
  grad_weight.add_(chunk_grad_weight)
@@ -191,6 +199,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
191
199
  _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
192
200
  _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
193
201
 
202
+ if nll_target is not None:
203
+ _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
204
+
194
205
  if use_ref_model:
195
206
  _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
196
207
  _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
@@ -202,6 +213,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
202
213
  rejected_target_chunk,
203
214
  ref_chosen_input_chunk,
204
215
  ref_rejected_input_chunk,
216
+ chosen_nll_target_chunk,
205
217
  ) in zip(
206
218
  _chosen_input_chunks,
207
219
  _rejected_input_chunks,
@@ -209,6 +221,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
209
221
  _rejected_target_chunks,
210
222
  (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
211
223
  (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
224
+ (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
212
225
  strict=False,
213
226
  ):
214
227
  input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
@@ -222,9 +235,10 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
222
235
  torch._dynamo.mark_dynamic(target_chunk, 1)
223
236
  torch._dynamo.mark_dynamic(target, 1)
224
237
  torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
238
+ torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None
225
239
 
226
240
  # accumulate loss, gradients, and metrics
227
- accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
241
+ accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
228
242
 
229
243
  # combine grad_chosen_inputs and grad_rejected_inputs
230
244
  grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -258,7 +272,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
258
272
  grad_weight = grad_weight * grad_output[0][0]
259
273
  grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
260
274
 
261
- return grad_input, grad_weight, None, grad_bias, None, None, None
275
+ return grad_input, grad_weight, None, grad_bias, None, None, None, None
262
276
 
263
277
  @staticmethod
264
278
  def chunk_forward(
@@ -268,6 +282,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
268
282
  bias=None,
269
283
  ignore_index=-100,
270
284
  compute_nll_loss=True,
285
+ chosen_nll_target_chunk=None,
271
286
  average_log_prob=True,
272
287
  ):
273
288
  len_chosen_chunk = target_chunk.shape[0] // 2
@@ -278,9 +293,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
278
293
 
279
294
  chosen_nll_loss = 0.0
280
295
  if compute_nll_loss:
296
+ nll_labels = (
297
+ chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
298
+ )
281
299
  chosen_nll_loss = F.nll_loss(
282
300
  log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
283
- target_chunk[:len_chosen_chunk].view(-1),
301
+ nll_labels.view(-1),
284
302
  reduction="sum",
285
303
  ignore_index=ignore_index,
286
304
  )
@@ -324,6 +342,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
324
342
  ref_input_chunk=None,
325
343
  ref_weight=None,
326
344
  ref_bias=None,
345
+ full_nll_target=None,
346
+ chosen_nll_target_chunk=None,
327
347
  average_log_prob=True,
328
348
  **loss_kwargs,
329
349
  ):
@@ -343,6 +363,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
343
363
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
344
364
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
345
365
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
366
+ full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
367
+ chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used.
346
368
  average_log_prob (bool): Whether to average log probabilities or the sum.
347
369
  loss_kwargs (dict): Additional arguments for the loss function.
348
370
  """
@@ -359,9 +381,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
359
381
  bias=bias,
360
382
  ignore_index=ignore_index,
361
383
  compute_nll_loss=compute_nll_loss,
384
+ chosen_nll_target_chunk=chosen_nll_target_chunk,
362
385
  average_log_prob=average_log_prob,
363
386
  )
364
- chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
387
+ if full_nll_target is not None:
388
+ chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
389
+ else:
390
+ chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
391
+
365
392
  chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
366
393
  rejected_logits_mean = rejected_logits.sum() / (
367
394
  full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
@@ -372,9 +399,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
372
399
  (
373
400
  ref_chosen_logps,
374
401
  ref_rejected_logps,
375
- ref_chosen_logits,
376
- ref_rejected_logits,
377
- ref_chosen_nll_loss,
402
+ _,
403
+ _,
404
+ _,
378
405
  ) = LigerFusedLinearPreferenceBase.chunk_forward(
379
406
  ref_input_chunk,
380
407
  ref_weight,
@@ -382,6 +409,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
382
409
  ref_bias,
383
410
  ignore_index=ignore_index,
384
411
  compute_nll_loss=False, # We don't need NLL loss for the reference model
412
+ chosen_nll_target_chunk=None,
385
413
  average_log_prob=average_log_prob,
386
414
  )
387
415
  loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
@@ -52,6 +52,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
52
52
  ignore_index=-100,
53
53
  beta=0.1,
54
54
  compute_nll_loss=True,
55
+ nll_target=None,
55
56
  compiled=True,
56
57
  ):
57
58
  return LigerFusedLinearPreferenceBase.forward(
@@ -64,13 +65,14 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
64
65
  ignore_index=ignore_index,
65
66
  beta=beta,
66
67
  compute_nll_loss=compute_nll_loss,
68
+ nll_target=nll_target,
67
69
  compiled=compiled,
68
70
  )
69
71
 
70
72
  @staticmethod
71
73
  def backward(ctx, *grad_output):
72
74
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
73
- return *grads, None, None, None, None
75
+ return *grads, None, None, None, None, None
74
76
 
75
77
 
76
78
  class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -96,7 +98,7 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
96
98
  self.compute_nll_loss = compute_nll_loss
97
99
  self.compiled = compiled
98
100
 
99
- def forward(self, lin_weight, _input, target, bias=None):
101
+ def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
100
102
  return LigerFusedLinearORPOFunction.apply(
101
103
  _input,
102
104
  lin_weight,
@@ -105,5 +107,6 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
105
107
  self.ignore_index,
106
108
  self.beta,
107
109
  self.compute_nll_loss,
110
+ nll_target,
108
111
  self.compiled,
109
112
  )
@@ -93,6 +93,13 @@ class LigerORPOTrainer(ORPOTrainer):
93
93
  if self.aux_loss_enabled:
94
94
  model_kwargs["output_router_logits"] = True
95
95
 
96
+ if self.is_encoder_decoder:
97
+ labels = concatenated_batch["concatenated_labels"].clone()
98
+ else:
99
+ labels = concatenated_batch["concatenated_input_ids"].clone()
100
+ attention_mask = concatenated_batch["concatenated_attention_mask"]
101
+ labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
102
+
96
103
  if isinstance(model, FullyShardedDataParallel):
97
104
  outputs = _FSDPForwardRedirection()(
98
105
  model,
@@ -114,15 +121,20 @@ class LigerORPOTrainer(ORPOTrainer):
114
121
 
115
122
  orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
116
123
 
117
- def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
118
- return orpo_loss_fn(lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias)
124
+ def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target):
125
+ return orpo_loss_fn(
126
+ lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target
127
+ )
119
128
 
120
129
  orpo_loss, aux_outputs = _FSDPForwardRedirection()(
121
130
  model,
122
131
  orpo_partial,
123
132
  model.lm_head,
124
- outputs.last_hidden_state,
125
- concatenated_batch["concatenated_labels"],
133
+ outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
134
+ concatenated_batch["concatenated_labels"][:, 1:]
135
+ if not self.is_encoder_decoder
136
+ else concatenated_batch["concatenated_labels"],
137
+ labels[:, 1:] if not self.is_encoder_decoder else labels,
126
138
  )
127
139
  # if aux_loss_enabled, add the aux_loss to the orpo_loss
128
140
  if self.aux_loss_enabled:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20250108102127
3
+ Version: 0.5.2.dev20250109023714
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -86,8 +86,8 @@ class TorchLMHeadORPO(torch.nn.Module):
86
86
  self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
87
87
  self.orpo_loss = HFORPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics
88
88
 
89
- def forward(self, x, y):
90
- return self.orpo_loss(self.lin.weight, x, y, self.lin.bias)
89
+ def forward(self, x, y, nll_target=None):
90
+ return self.orpo_loss(self.lin.weight, x, y, self.lin.bias, nll_target=nll_target)
91
91
 
92
92
 
93
93
  class LigerLMHeadORPO(torch.nn.Module):
@@ -104,8 +104,8 @@ class LigerLMHeadORPO(torch.nn.Module):
104
104
  self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
105
105
  self.orpo_loss = LigerFusedLinearORPOLoss(ignore_index=ignore_index, beta=beta)
106
106
 
107
- def forward(self, x, y):
108
- return self.orpo_loss(self.lin.weight, x, y, self.lin.bias)
107
+ def forward(self, x, y, nll_target=None):
108
+ return self.orpo_loss(self.lin.weight, x, y, self.lin.bias, nll_target=nll_target)
109
109
 
110
110
 
111
111
  @pytest.mark.parametrize(
@@ -164,13 +164,15 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index,
164
164
  device=device,
165
165
  dtype=torch.long,
166
166
  )
167
+ nll_target = torch.randint(0, V, (B, T), device=device, dtype=torch.long)
168
+
167
169
  # Assign some random number of elements as ignore_index
168
170
  num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
169
171
  indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
170
172
  target.view(-1)[indices_to_assign] = ignore_index
171
173
 
172
- loss1, aggregated_aux_outputs1 = torch_lm_head_orpo(input1, target)
173
- loss2, aggregated_aux_outputs2 = liger_lm_head_orpo(input2, target)
174
+ loss1, aggregated_aux_outputs1 = torch_lm_head_orpo(input1, target, nll_target)
175
+ loss2, aggregated_aux_outputs2 = liger_lm_head_orpo(input2, target, nll_target)
174
176
 
175
177
  assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
176
178
 
@@ -244,8 +246,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias):
244
246
  bias1 = _bias.detach().clone().requires_grad_(True) if bias else None
245
247
  bias2 = _bias.detach().clone().requires_grad_(True) if bias else None
246
248
 
247
- loss1, aggregated_aux_outputs1 = LigerFusedLinearORPOFunction.apply(input1, weight1, target, bias1)
248
- loss2, aggregated_aux_outputs2 = liger_fused_linear_orpo(input2, weight2, target, bias2)
249
+ loss1, _ = LigerFusedLinearORPOFunction.apply(input1, weight1, target, bias1)
250
+ loss2, _ = liger_fused_linear_orpo(input2, weight2, target, bias2)
249
251
 
250
252
  assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
251
253
 
@@ -406,8 +406,9 @@ class HFAlignmentLoss:
406
406
  _input: torch.FloatTensor,
407
407
  weight: torch.FloatTensor,
408
408
  target: torch.LongTensor,
409
- bias: torch.FloatTensor = None,
409
+ bias: torch.FloatTensor | None = None,
410
410
  average_log_prob: bool = True,
411
+ nll_target: torch.LongTensor | None = None,
411
412
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
412
413
  """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
413
414
 
@@ -430,7 +431,7 @@ class HFAlignmentLoss:
430
431
  loss = loss_fct(logits, labels)
431
432
  return loss
432
433
 
433
- labels = target
434
+ labels = nll_target if nll_target is not None else target
434
435
  chosen_nll_loss = torch.tensor(0.0, device=all_logits.device)
435
436
  if self.compute_nll_loss:
436
437
  chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
@@ -465,10 +466,11 @@ class HFAlignmentLoss:
465
466
  ref_weight: torch.FloatTensor = None,
466
467
  ref_bias: torch.FloatTensor = None,
467
468
  average_log_prob: bool = True,
469
+ nll_target: torch.LongTensor = None,
468
470
  ):
469
471
  """Compute the loss metrics for the given batch of inputs for train or test."""
470
472
 
471
- forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob)
473
+ forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob, nll_target)
472
474
  (
473
475
  policy_chosen_logps,
474
476
  policy_rejected_logps,