liger-kernel-nightly 0.5.2.dev20241220004933__tar.gz → 0.5.2.dev20241220220835__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (199) hide show
  1. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/chunked_loss/cpo_loss.py +15 -3
  4. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/chunked_loss/simpo_loss.py +18 -3
  5. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  6. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/chunked_loss/test_cpo_loss.py +23 -2
  7. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/chunked_loss/test_simpo_loss.py +22 -2
  8. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/.flake8 +0 -0
  9. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  10. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  11. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/.github/pull_request_template.md +0 -0
  12. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/.github/workflows/amd-ci.yml +0 -0
  13. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/.github/workflows/nvi-ci.yml +0 -0
  14. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/.github/workflows/publish-nightly.yml +0 -0
  15. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/.github/workflows/publish-release.yml +0 -0
  16. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/.gitignore +0 -0
  17. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/.isort.cfg +0 -0
  18. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/LICENSE +0 -0
  19. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/Makefile +0 -0
  20. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/NOTICE +0 -0
  21. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/README.md +0 -0
  22. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/__init__.py +0 -0
  23. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/benchmarks_visualizer.py +0 -0
  24. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/data/all_benchmark_data.csv +0 -0
  25. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/__init__.py +0 -0
  26. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  27. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  28. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  29. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_embedding.py +0 -0
  30. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  31. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  32. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_geglu.py +0 -0
  33. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_group_norm.py +0 -0
  34. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_jsd.py +0 -0
  35. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_kl_div.py +0 -0
  36. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  37. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  38. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  39. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  40. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_rope.py +0 -0
  41. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  42. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/benchmark_swiglu.py +0 -0
  43. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/benchmark/scripts/utils.py +0 -0
  44. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/dev/fmt-requirements.txt +0 -0
  45. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/dev/modal/tests.py +0 -0
  46. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/dev/modal/tests_bwd.py +0 -0
  47. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/docs/Acknowledgement.md +0 -0
  48. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/docs/CONTRIBUTING.md +0 -0
  49. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/docs/License.md +0 -0
  50. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/docs/images/banner.GIF +0 -0
  51. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/docs/images/compose.gif +0 -0
  52. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/docs/images/e2e-memory.png +0 -0
  53. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/docs/images/e2e-tps.png +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/docs/images/logo-banner.png +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/docs/images/patch.gif +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/docs/images/post-training.png +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/alignment/accelerate_config.yaml +0 -0
  58. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/alignment/run_orpo.py +0 -0
  59. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/README.md +0 -0
  60. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/callback.py +0 -0
  61. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/config/fsdp_config.json +0 -0
  62. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  63. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  64. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  65. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/img/llama_tps.png +0 -0
  66. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  67. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/img/qwen_tps.png +0 -0
  68. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/launch_on_modal.py +0 -0
  69. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/requirements.txt +0 -0
  70. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/run_benchmarks.sh +0 -0
  71. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/run_gemma.sh +0 -0
  72. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/run_llama.sh +0 -0
  73. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/run_qwen.sh +0 -0
  74. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/run_qwen2_vl.sh +0 -0
  75. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/training.py +0 -0
  76. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/huggingface/training_multimodal.py +0 -0
  77. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/lightning/README.md +0 -0
  78. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/lightning/requirements.txt +0 -0
  79. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/lightning/training.py +0 -0
  80. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/README.md +0 -0
  81. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/callback.py +0 -0
  82. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  83. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  84. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  85. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  86. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  87. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  88. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  89. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  90. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  91. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/medusa_util.py +0 -0
  92. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/requirements.txt +0 -0
  93. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  94. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/examples/medusa/train.py +0 -0
  95. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/licenses/LICENSE-Apache-2.0 +0 -0
  96. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  97. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  98. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/licenses/LICENSE-MIT-llmc +0 -0
  99. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/licenses/LICENSE-MIT-triton +0 -0
  100. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/setup.cfg +0 -0
  101. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/setup.py +0 -0
  102. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/__init__.py +0 -0
  103. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/chunked_loss/README.md +0 -0
  104. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  105. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  106. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/chunked_loss/functional.py +0 -0
  107. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  108. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  109. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  110. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/env_report.py +0 -0
  111. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/__init__.py +0 -0
  112. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/cross_entropy.py +0 -0
  113. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  114. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  115. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  116. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  117. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/geglu.py +0 -0
  118. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/group_norm.py +0 -0
  119. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/jsd.py +0 -0
  120. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/kl_div.py +0 -0
  121. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/layer_norm.py +0 -0
  122. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  123. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/rms_norm.py +0 -0
  124. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/rope.py +0 -0
  125. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/swiglu.py +0 -0
  126. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/ops/utils.py +0 -0
  127. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/__init__.py +0 -0
  128. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/auto_model.py +0 -0
  129. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  130. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  131. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/functional.py +0 -0
  132. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  133. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  134. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/geglu.py +0 -0
  135. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/group_norm.py +0 -0
  136. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/jsd.py +0 -0
  137. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/kl_div.py +0 -0
  138. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/layer_norm.py +0 -0
  139. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/model/__init__.py +0 -0
  140. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/model/gemma.py +0 -0
  141. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  142. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/model/llama.py +0 -0
  143. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/model/mistral.py +0 -0
  144. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  145. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/model/mllama.py +0 -0
  146. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/model/phi3.py +0 -0
  147. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  148. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  149. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  150. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  151. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/rms_norm.py +0 -0
  152. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/rope.py +0 -0
  153. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/swiglu.py +0 -0
  154. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  155. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  156. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  157. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/triton/__init__.py +0 -0
  158. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/triton/monkey_patch.py +0 -0
  159. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel/utils.py +0 -0
  160. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  161. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  162. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  163. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  164. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/__init__.py +0 -0
  165. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/chunked_loss/__init__.py +0 -0
  166. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/chunked_loss/test_dpo_loss.py +0 -0
  167. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/chunked_loss/test_orpo_loss.py +0 -0
  168. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/conftest.py +0 -0
  169. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/convergence/__init__.py +0 -0
  170. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/convergence/test_mini_models.py +0 -0
  171. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/convergence/test_mini_models_multimodal.py +0 -0
  172. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/convergence/test_mini_models_with_logits.py +0 -0
  173. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  174. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  175. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  176. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/resources/tiny_shakespeare.txt +0 -0
  177. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  178. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  179. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  180. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_auto_model.py +0 -0
  181. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_cross_entropy.py +0 -0
  182. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_embedding.py +0 -0
  183. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  184. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_fused_linear_jsd.py +0 -0
  185. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_geglu.py +0 -0
  186. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_group_norm.py +0 -0
  187. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_jsd.py +0 -0
  188. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_kl_div.py +0 -0
  189. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_layer_norm.py +0 -0
  190. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_mm_int8int2.py +0 -0
  191. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_monkey_patch.py +0 -0
  192. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_qwen2vl_mrope.py +0 -0
  193. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_rms_norm.py +0 -0
  194. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_rope.py +0 -0
  195. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_swiglu.py +0 -0
  196. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_trainer_integration.py +0 -0
  197. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/transformers/test_transformers.py +0 -0
  198. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/triton/test_triton_monkey_patch.py +0 -0
  199. {liger_kernel_nightly-0.5.2.dev20241220004933 → liger_kernel_nightly-0.5.2.dev20241220220835}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241220004933
3
+ Version: 0.5.2.dev20241220220835
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.2.dev20241220004933"
7
+ version = "0.5.2.dev20241220220835"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -9,7 +9,9 @@ from liger_kernel.chunked_loss.fused_linear_preference import (
9
9
  class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
10
10
 
11
11
  @staticmethod
12
- def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
12
+ def preference_loss_fn(
13
+ chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
14
+ ):
13
15
  """
14
16
  Paper: https://arxiv.org/pdf/2401.08417
15
17
 
@@ -30,9 +32,14 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
30
32
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
31
33
  full_target (torch.Tensor): Non chunked full target tensor
32
34
  beta (float): Weight for the CPO loss
35
+ label_smoothing (float): Label smoothing factor; the loss reduces to the equation above as label_smoothing -> 0.
33
36
  """
34
37
  logits = beta * (chosen_logps - rejected_logps)
35
- loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
38
+ loss = (
39
+ F.logsigmoid(logits) * (1 - label_smoothing)
40
+ + F.logsigmoid(-logits) * label_smoothing
41
+ ).sum() / (full_target.shape[0] // 2)
42
+
36
43
  return loss
37
44
 
38
45
  @staticmethod
@@ -45,6 +52,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
45
52
  ignore_index=-100,
46
53
  beta=0.1,
47
54
  alpha=1.0,
55
+ label_smoothing=0.0,
48
56
  compute_nll_loss=True,
49
57
  compiled=True,
50
58
  ):
@@ -58,6 +66,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
58
66
  ignore_index=ignore_index,
59
67
  alpha=alpha,
60
68
  beta=beta,
69
+ label_smoothing=label_smoothing,
61
70
  compute_nll_loss=compute_nll_loss,
62
71
  compiled=compiled,
63
72
  )
@@ -65,7 +74,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
65
74
  @staticmethod
66
75
  def backward(ctx, *grad_output):
67
76
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
68
- return *grads, None, None, None, None, None
77
+ return *grads, None, None, None, None, None, None
69
78
 
70
79
 
71
80
  class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -78,6 +87,7 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
78
87
  ignore_index: int = -100,
79
88
  beta: float = 0.1,
80
89
  alpha: float = 1.0,
90
+ label_smoothing: float = 0.0,
81
91
  compute_nll_loss: bool = True,
82
92
  compiled: bool = True,
83
93
  ):
@@ -90,6 +100,7 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
90
100
  self.ignore_index = ignore_index
91
101
  self.beta = beta
92
102
  self.alpha = alpha
103
+ self.label_smoothing = label_smoothing
93
104
  self.compute_nll_loss = compute_nll_loss
94
105
  self.compiled = compiled
95
106
 
@@ -102,6 +113,7 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
102
113
  self.ignore_index,
103
114
  self.beta,
104
115
  self.alpha,
116
+ self.label_smoothing,
105
117
  self.compute_nll_loss,
106
118
  self.compiled,
107
119
  )
@@ -10,7 +10,12 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
10
10
 
11
11
  @staticmethod
12
12
  def preference_loss_fn(
13
- chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
13
+ chosen_logps,
14
+ rejected_logps,
15
+ full_target,
16
+ beta=0.1,
17
+ gamma=0.5,
18
+ label_smoothing=0.0,
14
19
  ):
15
20
  """
16
21
  Paper: https://arxiv.org/pdf/2405.14734
@@ -33,9 +38,14 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
33
38
  full_target: Non chunked full target tensor
34
39
  beta (float): beta weight
35
40
  gamma (float): gamma margin term
41
+ label_smoothing (float): Label smoothing factor; the loss reduces to the equation above when label_smoothing is 0.
36
42
  """
37
43
  logits = beta * (chosen_logps - rejected_logps) - gamma
38
- loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
44
+ loss = (
45
+ F.logsigmoid(logits) * (1 - label_smoothing)
46
+ + F.logsigmoid(-logits) * label_smoothing
47
+ ).sum() / (full_target.shape[0] // 2)
48
+
39
49
  return loss
40
50
 
41
51
  @staticmethod
@@ -48,6 +58,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
48
58
  ignore_index=-100,
49
59
  beta=0.1,
50
60
  alpha=1.0,
61
+ label_smoothing=0.0,
51
62
  compute_nll_loss=False,
52
63
  compiled=True,
53
64
  gamma=0.5,
@@ -63,6 +74,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
63
74
  ignore_index=ignore_index,
64
75
  alpha=alpha,
65
76
  beta=beta,
77
+ label_smoothing=label_smoothing,
66
78
  compiled=compiled,
67
79
  gamma=gamma,
68
80
  )
@@ -70,7 +82,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
70
82
  @staticmethod
71
83
  def backward(ctx, *grad_output):
72
84
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
73
- return *grads, None, None, None, None, None, None
85
+ return *grads, None, None, None, None, None, None, None
74
86
 
75
87
 
76
88
  class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -83,6 +95,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
83
95
  ignore_index: int = -100,
84
96
  beta: float = 0.1,
85
97
  alpha: float = 1.0,
98
+ label_smoothing: float = 0.0,
86
99
  compute_nll_loss: bool = True,
87
100
  compiled: bool = True,
88
101
  gamma: float = 0.5,
@@ -96,6 +109,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
96
109
  self.ignore_index = ignore_index
97
110
  self.beta = beta
98
111
  self.alpha = alpha
112
+ self.label_smoothing = label_smoothing
99
113
  self.compute_nll_loss = compute_nll_loss
100
114
  self.compiled = compiled
101
115
  self.gamma = gamma
@@ -109,6 +123,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
109
123
  self.ignore_index,
110
124
  self.beta,
111
125
  self.alpha,
126
+ self.label_smoothing,
112
127
  self.compute_nll_loss,
113
128
  self.compiled,
114
129
  self.gamma,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241220004933
3
+ Version: 0.5.2.dev20241220220835
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -86,6 +86,7 @@ class TorchLMHeadCPO(torch.nn.Module):
86
86
  ignore_index: int = -100,
87
87
  beta: float = 0.1,
88
88
  alpha: float = 1.0,
89
+ label_smoothing: float = 0.0,
89
90
  loss_type: str = "sigmoid",
90
91
  simpo_gamma: float = 0.5,
91
92
  ):
@@ -97,6 +98,7 @@ class TorchLMHeadCPO(torch.nn.Module):
97
98
  ignore_index=ignore_index,
98
99
  beta=beta,
99
100
  loss_type=loss_type,
101
+ label_smoothing=label_smoothing,
100
102
  simpo_gamma=simpo_gamma,
101
103
  ).get_batch_loss_metrics
102
104
 
@@ -114,13 +116,17 @@ class LigerLMHeadCPO(torch.nn.Module):
114
116
  ignore_index: int = -100,
115
117
  beta: float = 0.1,
116
118
  alpha: float = 1.0,
119
+ label_smoothing: float = 0.0,
117
120
  ):
118
121
  super().__init__()
119
122
  self.lin = torch.nn.Linear(
120
123
  in_features=H, out_features=V, bias=bias, dtype=dtype
121
124
  )
122
125
  self.cpo_loss = LigerFusedLinearCPOLoss(
123
- ignore_index=ignore_index, beta=beta, alpha=alpha
126
+ ignore_index=ignore_index,
127
+ beta=beta,
128
+ alpha=alpha,
129
+ label_smoothing=label_smoothing,
124
130
  )
125
131
 
126
132
  def forward(self, x, y):
@@ -145,8 +151,21 @@ class LigerLMHeadCPO(torch.nn.Module):
145
151
  @pytest.mark.parametrize(
146
152
  "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)]
147
153
  )
154
+ @pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
148
155
  def test_correctness(
149
- B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha
156
+ B,
157
+ T,
158
+ H,
159
+ V,
160
+ scalar,
161
+ dtype,
162
+ atol,
163
+ rtol,
164
+ bias,
165
+ ignore_index,
166
+ beta,
167
+ alpha,
168
+ label_smoothing,
150
169
  ):
151
170
  B = 2 * B # cpo loss requires B to be even
152
171
 
@@ -157,6 +176,7 @@ def test_correctness(
157
176
  bias=bias,
158
177
  ignore_index=ignore_index,
159
178
  beta=beta,
179
+ label_smoothing=label_smoothing,
160
180
  )
161
181
  liger_lm_head_cpo = LigerLMHeadCPO(
162
182
  H=H,
@@ -165,6 +185,7 @@ def test_correctness(
165
185
  bias=bias,
166
186
  ignore_index=ignore_index,
167
187
  beta=beta,
188
+ label_smoothing=label_smoothing,
168
189
  )
169
190
 
170
191
  torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn(
@@ -25,6 +25,7 @@ class LigerLMHeadSimPO(torch.nn.Module):
25
25
  ignore_index: int = -100,
26
26
  beta: float = 0.1,
27
27
  alpha: float = 1.0,
28
+ label_smoothing: float = 0.0,
28
29
  gamma: float = 0.5,
29
30
  ):
30
31
  super().__init__()
@@ -32,7 +33,11 @@ class LigerLMHeadSimPO(torch.nn.Module):
32
33
  in_features=H, out_features=V, bias=bias, dtype=dtype
33
34
  )
34
35
  self.simpo_loss = LigerFusedLinearSimPOLoss(
35
- ignore_index=ignore_index, beta=beta, alpha=alpha, gamma=gamma
36
+ ignore_index=ignore_index,
37
+ beta=beta,
38
+ alpha=alpha,
39
+ gamma=gamma,
40
+ label_smoothing=label_smoothing,
36
41
  )
37
42
 
38
43
  def forward(self, x, y):
@@ -57,8 +62,21 @@ class LigerLMHeadSimPO(torch.nn.Module):
57
62
  @pytest.mark.parametrize(
58
63
  "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)]
59
64
  )
65
+ @pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
60
66
  def test_correctness(
61
- B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma
67
+ B,
68
+ T,
69
+ H,
70
+ V,
71
+ scalar,
72
+ dtype,
73
+ atol,
74
+ rtol,
75
+ bias,
76
+ ignore_index,
77
+ beta,
78
+ gamma,
79
+ label_smoothing,
62
80
  ):
63
81
  B = 2 * B # SimPO loss requires B to be even
64
82
 
@@ -70,6 +88,7 @@ def test_correctness(
70
88
  ignore_index=ignore_index,
71
89
  beta=beta,
72
90
  loss_type="simpo",
91
+ label_smoothing=label_smoothing,
73
92
  simpo_gamma=gamma,
74
93
  )
75
94
  liger_lm_head_simpo = LigerLMHeadSimPO(
@@ -79,6 +98,7 @@ def test_correctness(
79
98
  bias=bias,
80
99
  ignore_index=ignore_index,
81
100
  beta=beta,
101
+ label_smoothing=label_smoothing,
82
102
  gamma=gamma,
83
103
  )
84
104