liger-kernel-nightly 0.5.2.dev20241212030605__tar.gz → 0.5.2.dev20241212055403__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (198)
  1. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/dev/modal/tests_bwd.py +1 -1
  3. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/chunked_loss/dpo_loss.py +12 -2
  5. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/chunked_loss/fused_linear_preference.py +42 -9
  6. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  7. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/chunked_loss/test_dpo_loss.py +28 -8
  8. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/utils.py +2 -1
  9. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/.flake8 +0 -0
  10. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  11. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  12. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/.github/pull_request_template.md +0 -0
  13. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/.github/workflows/amd-ci.yml +0 -0
  14. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/.github/workflows/nvi-ci.yml +0 -0
  15. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/.github/workflows/publish-nightly.yml +0 -0
  16. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/.github/workflows/publish-release.yml +0 -0
  17. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/.gitignore +0 -0
  18. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/.isort.cfg +0 -0
  19. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/LICENSE +0 -0
  20. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/Makefile +0 -0
  21. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/NOTICE +0 -0
  22. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/README.md +0 -0
  23. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/__init__.py +0 -0
  24. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/benchmarks_visualizer.py +0 -0
  25. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/data/all_benchmark_data.csv +0 -0
  26. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/__init__.py +0 -0
  27. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  28. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  29. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  30. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_embedding.py +0 -0
  31. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  33. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_geglu.py +0 -0
  34. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_group_norm.py +0 -0
  35. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_jsd.py +0 -0
  36. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_kl_div.py +0 -0
  37. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  38. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  39. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  40. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  41. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_rope.py +0 -0
  42. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  43. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/benchmark_swiglu.py +0 -0
  44. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/benchmark/scripts/utils.py +0 -0
  45. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/dev/fmt-requirements.txt +0 -0
  46. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/dev/modal/tests.py +0 -0
  47. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/docs/Acknowledgement.md +0 -0
  48. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/docs/CONTRIBUTING.md +0 -0
  49. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/docs/License.md +0 -0
  50. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/docs/images/banner.GIF +0 -0
  51. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/docs/images/compose.gif +0 -0
  52. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/docs/images/e2e-memory.png +0 -0
  53. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/docs/images/e2e-tps.png +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/docs/images/logo-banner.png +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/docs/images/patch.gif +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/alignment/accelerate_config.yaml +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/alignment/run_orpo.py +0 -0
  58. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/README.md +0 -0
  59. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/callback.py +0 -0
  60. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/config/fsdp_config.json +0 -0
  61. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  62. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  63. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  64. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/img/llama_tps.png +0 -0
  65. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  66. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/img/qwen_tps.png +0 -0
  67. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/launch_on_modal.py +0 -0
  68. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/requirements.txt +0 -0
  69. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/run_benchmarks.sh +0 -0
  70. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/run_gemma.sh +0 -0
  71. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/run_llama.sh +0 -0
  72. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/run_qwen.sh +0 -0
  73. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/run_qwen2_vl.sh +0 -0
  74. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/training.py +0 -0
  75. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/huggingface/training_multimodal.py +0 -0
  76. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/lightning/README.md +0 -0
  77. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/lightning/requirements.txt +0 -0
  78. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/lightning/training.py +0 -0
  79. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/README.md +0 -0
  80. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/callback.py +0 -0
  81. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  82. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  83. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  84. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  85. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  86. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  87. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  88. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  89. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  90. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/medusa_util.py +0 -0
  91. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/requirements.txt +0 -0
  92. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  93. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/examples/medusa/train.py +0 -0
  94. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/licenses/LICENSE-Apache-2.0 +0 -0
  95. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  96. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  97. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/licenses/LICENSE-MIT-llmc +0 -0
  98. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/licenses/LICENSE-MIT-triton +0 -0
  99. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/setup.cfg +0 -0
  100. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/setup.py +0 -0
  101. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/__init__.py +0 -0
  102. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/chunked_loss/README.md +0 -0
  103. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  104. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  105. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/chunked_loss/functional.py +0 -0
  106. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  107. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  108. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  109. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/env_report.py +0 -0
  110. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/__init__.py +0 -0
  111. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/cross_entropy.py +0 -0
  112. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  113. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  114. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  115. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  116. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/geglu.py +0 -0
  117. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/group_norm.py +0 -0
  118. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/jsd.py +0 -0
  119. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/kl_div.py +0 -0
  120. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/layer_norm.py +0 -0
  121. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  122. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/rms_norm.py +0 -0
  123. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/rope.py +0 -0
  124. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/swiglu.py +0 -0
  125. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/ops/utils.py +0 -0
  126. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/__init__.py +0 -0
  127. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/auto_model.py +0 -0
  128. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  129. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  130. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/functional.py +0 -0
  131. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  132. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  133. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/geglu.py +0 -0
  134. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/group_norm.py +0 -0
  135. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/jsd.py +0 -0
  136. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/kl_div.py +0 -0
  137. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/layer_norm.py +0 -0
  138. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/model/__init__.py +0 -0
  139. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/model/gemma.py +0 -0
  140. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  141. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/model/llama.py +0 -0
  142. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/model/mistral.py +0 -0
  143. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  144. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/model/mllama.py +0 -0
  145. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/model/phi3.py +0 -0
  146. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  147. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  148. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  149. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  150. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/rms_norm.py +0 -0
  151. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/rope.py +0 -0
  152. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/swiglu.py +0 -0
  153. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  154. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  155. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  156. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/triton/__init__.py +0 -0
  157. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/triton/monkey_patch.py +0 -0
  158. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel/utils.py +0 -0
  159. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  160. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  161. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  162. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  163. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/__init__.py +0 -0
  164. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/chunked_loss/__init__.py +0 -0
  165. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/chunked_loss/test_cpo_loss.py +0 -0
  166. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/chunked_loss/test_orpo_loss.py +0 -0
  167. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/chunked_loss/test_simpo_loss.py +0 -0
  168. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/conftest.py +0 -0
  169. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/convergence/__init__.py +0 -0
  170. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/convergence/test_mini_models.py +0 -0
  171. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/convergence/test_mini_models_multimodal.py +0 -0
  172. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/convergence/test_mini_models_with_logits.py +0 -0
  173. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  174. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  175. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  176. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/resources/tiny_shakespeare.txt +0 -0
  177. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  178. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  179. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  180. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_auto_model.py +0 -0
  181. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_cross_entropy.py +0 -0
  182. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_embedding.py +0 -0
  183. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  184. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_fused_linear_jsd.py +0 -0
  185. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_geglu.py +0 -0
  186. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_group_norm.py +0 -0
  187. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_jsd.py +0 -0
  188. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_kl_div.py +0 -0
  189. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_layer_norm.py +0 -0
  190. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_mm_int8int2.py +0 -0
  191. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_monkey_patch.py +0 -0
  192. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_qwen2vl_mrope.py +0 -0
  193. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_rms_norm.py +0 -0
  194. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_rope.py +0 -0
  195. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_swiglu.py +0 -0
  196. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_trainer_integration.py +0 -0
  197. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/transformers/test_transformers.py +0 -0
  198. {liger_kernel_nightly-0.5.2.dev20241212030605 → liger_kernel_nightly-0.5.2.dev20241212055403}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241212030605
3
+ Version: 0.5.2.dev20241212055403
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -8,7 +8,7 @@ PYTHON_VERSION = "3.12"
8
8
 
9
9
  image = modal.Image.debian_slim(python_version=PYTHON_VERSION).pip_install("uv")
10
10
 
11
- app = modal.App("liger_tests", image=image)
11
+ app = modal.App("liger_tests_bwd", image=image)
12
12
 
13
13
  # mount: add local files to the remote container
14
14
  repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.2.dev20241212030605"
7
+ version = "0.5.2.dev20241212055403"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -59,6 +59,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
59
59
  weight,
60
60
  target,
61
61
  bias=None,
62
+ ref_input=None,
62
63
  ref_weight=None,
63
64
  ref_bias=None,
64
65
  ignore_index=-100,
@@ -79,6 +80,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
79
80
  compute_nll_loss=compute_nll_loss,
80
81
  compiled=compiled,
81
82
  use_ref_model=use_ref_model,
83
+ ref_input=ref_input,
82
84
  ref_weight=ref_weight,
83
85
  ref_bias=ref_bias,
84
86
  )
@@ -86,7 +88,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
86
88
  @staticmethod
87
89
  def backward(ctx, *grad_output):
88
90
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
89
- return *grads, None, None, None, None, None, None, None
91
+ return *grads, None, None, None, None, None, None, None, None
90
92
 
91
93
 
92
94
  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -118,13 +120,21 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
118
120
  self.use_ref_model = use_ref_model
119
121
 
120
122
  def forward(
121
- self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
123
+ self,
124
+ lin_weight,
125
+ _input,
126
+ target,
127
+ bias=None,
128
+ ref_input=None,
129
+ ref_weight=None,
130
+ ref_bias=None,
122
131
  ):
123
132
  return LigerFusedLinearDPOFunction.apply(
124
133
  _input,
125
134
  lin_weight,
126
135
  target,
127
136
  bias,
137
+ ref_input,
128
138
  ref_weight,
129
139
  ref_bias,
130
140
  self.ignore_index,
@@ -29,7 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
29
29
  compute_nll_loss=True,
30
30
  compiled=True,
31
31
  use_ref_model=False,
32
- # TODO: ref input
32
+ ref_input=None,
33
33
  ref_weight=None,
34
34
  ref_bias=None,
35
35
  **loss_kwargs,
@@ -97,20 +97,26 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
97
97
  **loss_kwargs,
98
98
  )
99
99
 
100
- def fused_fwd_bwd(input_chunk, target_chunk):
100
+ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
101
101
  """
102
102
  Fused forward and backward pass for a chunk of input and target.
103
103
  """
104
104
  if bias is not None:
105
105
  return torch.func.grad_and_value(
106
106
  compute_loss, argnums=(0, 1, 3), has_aux=True
107
- )(input_chunk, weight, target_chunk, bias)
107
+ )(
108
+ input_chunk,
109
+ weight,
110
+ target_chunk,
111
+ bias,
112
+ ref_input_chunk=ref_input_chunk,
113
+ )
108
114
  else:
109
115
  return torch.func.grad_and_value(
110
116
  compute_loss, argnums=(0, 1), has_aux=True
111
- )(input_chunk, weight, target_chunk)
117
+ )(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk)
112
118
 
113
- def accumulate_chunk(input_chunk, target_chunk):
119
+ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
114
120
  if bias is not None:
115
121
  (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
116
122
  chunk_loss,
@@ -122,7 +128,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
122
128
  chunk_nll_loss,
123
129
  *aux_outputs,
124
130
  ),
125
- ) = fused_fwd_bwd(input_chunk, target_chunk)
131
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
126
132
  grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
127
133
  else:
128
134
  (chunk_grad_input, chunk_grad_weight), (
@@ -135,7 +141,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
135
141
  chunk_nll_loss,
136
142
  *aux_outputs,
137
143
  ),
138
- ) = fused_fwd_bwd(input_chunk, target_chunk)
144
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
139
145
 
140
146
  # Accumulate gradients
141
147
  grad_weight.add_(chunk_grad_weight)
@@ -182,18 +188,43 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
182
188
  _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
183
189
  _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
184
190
 
191
+ if use_ref_model:
192
+ _ref_chosen_input_chunks = torch.chunk(
193
+ ref_input[:len_chosen], chunks=chunks, dim=0
194
+ )
195
+ _ref_rejected_input_chunks = torch.chunk(
196
+ ref_input[len_chosen:], chunks=chunks, dim=0
197
+ )
198
+
185
199
  for (
186
200
  chosen_input_chunk,
187
201
  rejected_input_chunk,
188
202
  chosen_target_chunk,
189
203
  rejected_target_chunk,
204
+ ref_chosen_input_chunk,
205
+ ref_rejected_input_chunk,
190
206
  ) in zip(
191
207
  _chosen_input_chunks,
192
208
  _rejected_input_chunks,
193
209
  _chosen_target_chunks,
194
210
  _rejected_target_chunks,
211
+ (
212
+ _ref_chosen_input_chunks
213
+ if use_ref_model
214
+ else [None] * len(_chosen_input_chunks)
215
+ ),
216
+ (
217
+ _ref_rejected_input_chunks
218
+ if use_ref_model
219
+ else [None] * len(_rejected_input_chunks)
220
+ ),
195
221
  ):
196
222
  input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
223
+ ref_input_chunk = (
224
+ torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0)
225
+ if use_ref_model
226
+ else None
227
+ )
197
228
  target_chunk = torch.cat(
198
229
  [chosen_target_chunk, rejected_target_chunk], dim=0
199
230
  )
@@ -202,9 +233,10 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
202
233
  torch._dynamo.mark_dynamic(input_chunk, 1)
203
234
  torch._dynamo.mark_dynamic(target_chunk, 1)
204
235
  torch._dynamo.mark_dynamic(target, 1)
236
+ torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
205
237
 
206
238
  # accumulate loss, gradients, and metrics
207
- accumulate_chunk(input_chunk, target_chunk)
239
+ accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
208
240
 
209
241
  # combine grad_chosen_inputs and grad_rejected_inputs
210
242
  grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -301,6 +333,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
301
333
  beta=0.1,
302
334
  compute_nll_loss=True,
303
335
  use_ref_model=False,
336
+ ref_input_chunk=None,
304
337
  ref_weight=None,
305
338
  ref_bias=None,
306
339
  **loss_kwargs,
@@ -357,7 +390,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
357
390
  ref_rejected_logits,
358
391
  ref_chosen_nll_loss,
359
392
  ) = LigerFusedLinearPreferenceBase.chunk_forward(
360
- input_chunk,
393
+ ref_input_chunk,
361
394
  ref_weight,
362
395
  target_chunk,
363
396
  ref_bias,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241212030605
3
+ Version: 0.5.2.dev20241212055403
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -75,9 +75,15 @@ class TorchLMHeadDPO(torch.nn.Module):
75
75
  ignore_index=ignore_index, beta=beta, use_ref_model=True
76
76
  ).get_batch_loss_metrics
77
77
 
78
- def forward(self, x, y):
78
+ def forward(self, x, ref_x, y):
79
79
  return self.dpo_loss(
80
- self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias
80
+ self.lin.weight,
81
+ x,
82
+ y,
83
+ self.lin.bias,
84
+ ref_x,
85
+ self.ref_lin.weight,
86
+ self.ref_lin.bias,
81
87
  )
82
88
 
83
89
 
@@ -103,9 +109,15 @@ class LigerLMHeadDPO(torch.nn.Module):
103
109
  ignore_index=ignore_index, beta=beta, use_ref_model=True
104
110
  )
105
111
 
106
- def forward(self, x, y):
112
+ def forward(self, x, ref_x, y):
107
113
  return self.dpo_loss(
108
- self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias
114
+ self.lin.weight,
115
+ x,
116
+ y,
117
+ self.lin.bias,
118
+ ref_x,
119
+ self.ref_lin.weight,
120
+ self.ref_lin.bias,
109
121
  )
110
122
 
111
123
 
@@ -170,6 +182,10 @@ def test_correctness(
170
182
  input1 = _input.detach().clone().requires_grad_(True)
171
183
  input2 = _input.detach().clone().requires_grad_(True)
172
184
 
185
+ ref_input = (
186
+ torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
187
+ )
188
+
173
189
  target = torch.randint(
174
190
  0,
175
191
  V,
@@ -185,8 +201,8 @@ def test_correctness(
185
201
  indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
186
202
  target.view(-1)[indices_to_assign] = ignore_index
187
203
 
188
- loss1, aggregated_aux_outputs1 = torch_lm_head_dpo(input1, target)
189
- loss2, aggregated_aux_outputs2 = liger_lm_head_dpo(input2, target)
204
+ loss1, aggregated_aux_outputs1 = torch_lm_head_dpo(input1, ref_input, target)
205
+ loss2, aggregated_aux_outputs2 = liger_lm_head_dpo(input2, ref_input, target)
190
206
 
191
207
  assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
192
208
 
@@ -242,6 +258,10 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
242
258
  input1 = _input.detach().clone().requires_grad_(True)
243
259
  input2 = _input.detach().clone().requires_grad_(True)
244
260
 
261
+ ref_input = (
262
+ torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
263
+ )
264
+
245
265
  target = torch.randint(
246
266
  0,
247
267
  V,
@@ -270,10 +290,10 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
270
290
  ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None
271
291
 
272
292
  loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply(
273
- input1, weight1, target, bias1, ref_weight1, ref_bias1
293
+ input1, weight1, target, bias1, ref_input, ref_weight1, ref_bias1
274
294
  )
275
295
  loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
276
- input2, weight2, target, bias2, ref_weight2, ref_bias2
296
+ input2, weight2, target, bias2, ref_input, ref_weight2, ref_bias2
277
297
  )
278
298
 
279
299
  assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
@@ -478,6 +478,7 @@ class HFAlignmentLoss:
478
478
  _input: torch.FloatTensor,
479
479
  target: torch.LongTensor,
480
480
  bias: torch.FloatTensor = None,
481
+ ref_input: torch.FloatTensor = None,
481
482
  ref_weight: torch.FloatTensor = None,
482
483
  ref_bias: torch.FloatTensor = None,
483
484
  average_log_prob: bool = True,
@@ -498,7 +499,7 @@ class HFAlignmentLoss:
498
499
  loss_kwargs = {}
499
500
  if self.use_ref_model:
500
501
  ref_chosen_logps, ref_rejected_logps = self.get_ref_logps(
501
- _input, ref_weight, target, ref_bias, average_log_prob
502
+ ref_input, ref_weight, target, ref_bias, average_log_prob
502
503
  )
503
504
  loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
504
505
  loss_kwargs["ref_rejected_logps"] = ref_rejected_logps