liger-kernel-nightly 0.5.2.dev20250108072837__tar.gz → 0.5.2.dev20250108102127__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (197) hide show
  1. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/chunked_loss/cpo_loss.py +1 -0
  4. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/chunked_loss/fused_linear_preference.py +14 -3
  5. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/cross_entropy.py +8 -24
  6. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/fused_linear_cross_entropy.py +4 -4
  7. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/cross_entropy.py +0 -3
  8. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  9. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/chunked_loss/test_cpo_loss.py +3 -2
  10. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  11. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  12. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/.github/pull_request_template.md +0 -0
  13. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/.github/workflows/amd-ci.yml +0 -0
  14. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/.github/workflows/nvi-ci.yml +0 -0
  15. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/.github/workflows/publish-nightly.yml +0 -0
  16. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/.github/workflows/publish-release.yml +0 -0
  17. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/.gitignore +0 -0
  18. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/LICENSE +0 -0
  19. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/Makefile +0 -0
  20. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/NOTICE +0 -0
  21. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/README.md +0 -0
  22. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/__init__.py +0 -0
  23. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/benchmarks_visualizer.py +0 -0
  24. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/data/all_benchmark_data.csv +0 -0
  25. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/__init__.py +0 -0
  26. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  27. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  28. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  29. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_embedding.py +0 -0
  30. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  31. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  32. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_geglu.py +0 -0
  33. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_group_norm.py +0 -0
  34. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_jsd.py +0 -0
  35. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_kl_div.py +0 -0
  36. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  37. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  38. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  39. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  40. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_rope.py +0 -0
  41. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  42. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/benchmark_swiglu.py +0 -0
  43. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/benchmark/scripts/utils.py +0 -0
  44. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/dev/fmt-requirements.txt +0 -0
  45. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/dev/modal/tests.py +0 -0
  46. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/dev/modal/tests_bwd.py +0 -0
  47. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/docs/Acknowledgement.md +0 -0
  48. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/docs/CONTRIBUTING.md +0 -0
  49. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/docs/License.md +0 -0
  50. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/docs/images/banner.GIF +0 -0
  51. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/docs/images/compose.gif +0 -0
  52. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/docs/images/e2e-memory.png +0 -0
  53. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/docs/images/e2e-tps.png +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/docs/images/logo-banner.png +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/docs/images/patch.gif +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/docs/images/post-training.png +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/alignment/accelerate_config.yaml +0 -0
  58. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/alignment/run_orpo.py +0 -0
  59. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/README.md +0 -0
  60. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/callback.py +0 -0
  61. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/config/fsdp_config.json +0 -0
  62. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  63. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  64. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  65. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/img/llama_tps.png +0 -0
  66. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  67. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/img/qwen_tps.png +0 -0
  68. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/launch_on_modal.py +0 -0
  69. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/requirements.txt +0 -0
  70. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/run_benchmarks.sh +0 -0
  71. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/run_gemma.sh +0 -0
  72. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/run_llama.sh +0 -0
  73. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/run_qwen.sh +0 -0
  74. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/run_qwen2_vl.sh +0 -0
  75. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/training.py +0 -0
  76. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/huggingface/training_multimodal.py +0 -0
  77. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/lightning/README.md +0 -0
  78. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/lightning/requirements.txt +0 -0
  79. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/lightning/training.py +0 -0
  80. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/README.md +0 -0
  81. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/callback.py +0 -0
  82. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  83. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  84. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  85. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  86. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  87. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  88. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  89. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  90. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  91. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/medusa_util.py +0 -0
  92. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/requirements.txt +0 -0
  93. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  94. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/examples/medusa/train.py +0 -0
  95. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/licenses/LICENSE-Apache-2.0 +0 -0
  96. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  97. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  98. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/licenses/LICENSE-MIT-llmc +0 -0
  99. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/licenses/LICENSE-MIT-triton +0 -0
  100. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/setup.cfg +0 -0
  101. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/setup.py +0 -0
  102. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/__init__.py +0 -0
  103. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/chunked_loss/README.md +0 -0
  104. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  105. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  106. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/chunked_loss/functional.py +0 -0
  107. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  108. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  109. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  110. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/env_report.py +0 -0
  111. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/__init__.py +0 -0
  112. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  113. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  114. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  115. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/geglu.py +0 -0
  116. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/group_norm.py +0 -0
  117. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/jsd.py +0 -0
  118. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/kl_div.py +0 -0
  119. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/layer_norm.py +0 -0
  120. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  121. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/rms_norm.py +0 -0
  122. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/rope.py +0 -0
  123. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/swiglu.py +0 -0
  124. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/ops/utils.py +0 -0
  125. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/__init__.py +0 -0
  126. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/auto_model.py +0 -0
  127. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  128. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/functional.py +0 -0
  129. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  130. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  131. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/geglu.py +0 -0
  132. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/group_norm.py +0 -0
  133. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/jsd.py +0 -0
  134. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/kl_div.py +0 -0
  135. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/layer_norm.py +0 -0
  136. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/model/__init__.py +0 -0
  137. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/model/gemma.py +0 -0
  138. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  139. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/model/llama.py +0 -0
  140. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/model/mistral.py +0 -0
  141. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  142. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/model/mllama.py +0 -0
  143. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/model/phi3.py +0 -0
  144. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  145. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  146. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  147. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  148. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/rms_norm.py +0 -0
  149. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/rope.py +0 -0
  150. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/swiglu.py +0 -0
  151. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  152. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  153. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  154. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/triton/__init__.py +0 -0
  155. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/triton/monkey_patch.py +0 -0
  156. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel/utils.py +0 -0
  157. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  158. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  159. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  160. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  161. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/__init__.py +0 -0
  162. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/chunked_loss/__init__.py +0 -0
  163. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/chunked_loss/test_dpo_loss.py +0 -0
  164. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/chunked_loss/test_orpo_loss.py +0 -0
  165. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/chunked_loss/test_simpo_loss.py +0 -0
  166. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/conftest.py +0 -0
  167. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/convergence/__init__.py +0 -0
  168. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/convergence/test_mini_models.py +0 -0
  169. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/convergence/test_mini_models_multimodal.py +0 -0
  170. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/convergence/test_mini_models_with_logits.py +0 -0
  171. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  172. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  173. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  174. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/resources/tiny_shakespeare.txt +0 -0
  175. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  176. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  177. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  178. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_auto_model.py +0 -0
  179. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_cross_entropy.py +0 -0
  180. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_embedding.py +0 -0
  181. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  182. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_fused_linear_jsd.py +0 -0
  183. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_geglu.py +0 -0
  184. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_group_norm.py +0 -0
  185. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_jsd.py +0 -0
  186. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_kl_div.py +0 -0
  187. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_layer_norm.py +0 -0
  188. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_mm_int8int2.py +0 -0
  189. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_monkey_patch.py +0 -0
  190. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_qwen2vl_mrope.py +0 -0
  191. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_rms_norm.py +0 -0
  192. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_rope.py +0 -0
  193. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_swiglu.py +0 -0
  194. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_trainer_integration.py +0 -0
  195. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/transformers/test_transformers.py +0 -0
  196. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/triton/test_triton_monkey_patch.py +0 -0
  197. {liger_kernel_nightly-0.5.2.dev20250108072837 → liger_kernel_nightly-0.5.2.dev20250108102127}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20250108072837
3
+ Version: 0.5.2.dev20250108102127
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.2.dev20250108072837"
7
+ version = "0.5.2.dev20250108102127"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -65,6 +65,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
65
65
  beta=beta,
66
66
  label_smoothing=label_smoothing,
67
67
  compute_nll_loss=compute_nll_loss,
68
+ average_log_prob=False,
68
69
  compiled=compiled,
69
70
  )
70
71
 
@@ -32,6 +32,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
32
32
  ref_input=None,
33
33
  ref_weight=None,
34
34
  ref_bias=None,
35
+ average_log_prob=True,
35
36
  **loss_kwargs,
36
37
  ):
37
38
  """
@@ -61,6 +62,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
61
62
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
62
63
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
63
64
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
65
+ average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
64
66
  loss_kwargs (dict): Other possible arguments that a loss function might need
65
67
  """
66
68
  # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -94,6 +96,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
94
96
  use_ref_model=use_ref_model,
95
97
  ref_weight=ref_weight,
96
98
  ref_bias=ref_bias,
99
+ average_log_prob=average_log_prob,
97
100
  **loss_kwargs,
98
101
  )
99
102
 
@@ -265,6 +268,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
265
268
  bias=None,
266
269
  ignore_index=-100,
267
270
  compute_nll_loss=True,
271
+ average_log_prob=True,
268
272
  ):
269
273
  len_chosen_chunk = target_chunk.shape[0] // 2
270
274
  logits_chunk = input_chunk @ weight.t()
@@ -285,10 +289,13 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
285
289
  label_chunk = torch.where(loss_mask, target_chunk, 0)
286
290
 
287
291
  per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
288
- average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
292
+ if average_log_prob:
293
+ log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
294
+ else:
295
+ log_prob = (per_token_logps * loss_mask).sum(-1)
289
296
 
290
- chosen_logps = average_log_prob[:len_chosen_chunk]
291
- rejected_logps = average_log_prob[len_chosen_chunk:]
297
+ chosen_logps = log_prob[:len_chosen_chunk]
298
+ rejected_logps = log_prob[len_chosen_chunk:]
292
299
 
293
300
  chosen_logits = logits_chunk[:len_chosen_chunk]
294
301
  rejected_logits = logits_chunk[len_chosen_chunk:]
@@ -317,6 +324,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
317
324
  ref_input_chunk=None,
318
325
  ref_weight=None,
319
326
  ref_bias=None,
327
+ average_log_prob=True,
320
328
  **loss_kwargs,
321
329
  ):
322
330
  """
@@ -335,6 +343,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
335
343
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
336
344
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
337
345
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
346
+ average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
338
347
  loss_kwargs (dict): Additional arguments for the loss function.
339
348
  """
340
349
  (
@@ -350,6 +359,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
350
359
  bias=bias,
351
360
  ignore_index=ignore_index,
352
361
  compute_nll_loss=compute_nll_loss,
362
+ average_log_prob=average_log_prob,
353
363
  )
354
364
  chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
355
365
  chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
@@ -372,6 +382,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
372
382
  ref_bias,
373
383
  ignore_index=ignore_index,
374
384
  compute_nll_loss=False, # We don't need NLL loss for the reference model
385
+ average_log_prob=average_log_prob,
375
386
  )
376
387
  loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
377
388
  loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
@@ -20,9 +20,6 @@ if compare_version("triton", operator.ge, "3.0.0"):
20
20
  else:
21
21
  from triton.language.math import tanh
22
22
 
23
- _TRUE: tl.constexpr = tl.constexpr(1)
24
- _FALSE: tl.constexpr = tl.constexpr(0)
25
-
26
23
 
27
24
  @triton.jit
28
25
  def liger_cross_entropy_kernel(
@@ -95,7 +92,7 @@ def liger_cross_entropy_kernel(
95
92
  return
96
93
 
97
94
  loss_ptr += program_id * loss_stride
98
- if RETURN_Z_LOSS == _TRUE:
95
+ if RETURN_Z_LOSS:
99
96
  z_loss_ptr += program_id * loss_stride
100
97
 
101
98
  if HAS_WEIGHT:
@@ -254,7 +251,7 @@ def liger_cross_entropy_kernel(
254
251
  loss += z_loss
255
252
 
256
253
  tl.store(loss_ptr, loss)
257
- if RETURN_Z_LOSS == _TRUE:
254
+ if RETURN_Z_LOSS:
258
255
  tl.store(z_loss_ptr, z_loss)
259
256
 
260
257
 
@@ -264,12 +261,6 @@ def liger_cross_entropy_kernel(
264
261
  MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
265
262
 
266
263
 
267
- _bool_to_return_z_loss = {
268
- True: _TRUE.value,
269
- False: _FALSE.value,
270
- }
271
-
272
-
273
264
  def cross_entropy_forward(
274
265
  _input,
275
266
  target,
@@ -281,11 +272,7 @@ def cross_entropy_forward(
281
272
  softcap,
282
273
  return_z_loss,
283
274
  ):
284
- if not isinstance(return_z_loss, int):
285
- assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
286
- return_z_loss = _bool_to_return_z_loss[return_z_loss]
287
- else:
288
- assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
275
+ assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
289
276
 
290
277
  BT, V = _input.shape
291
278
  n_rows = BT
@@ -294,10 +281,7 @@ def cross_entropy_forward(
294
281
 
295
282
  # unreduced loss
296
283
  loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
297
- if return_z_loss == _TRUE.value:
298
- z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
299
- else:
300
- z_loss_1d = None # set None when return_z_loss == False
284
+ z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
301
285
 
302
286
  target_mask = target != ignore_index
303
287
  n_non_ignore = target_mask.sum().item()
@@ -326,7 +310,7 @@ def cross_entropy_forward(
326
310
  X_stride=_input.stride(-2),
327
311
  Y_ptr=target,
328
312
  Y_stride=target.stride(-1), # always 1
329
- weight_ptr=weight if weight is not None else _input, # dummy if None
313
+ weight_ptr=weight, # may be None; HAS_WEIGHT tells the kernel whether it is set
330
314
  loss_ptr=loss_1d,
331
315
  z_loss_ptr=z_loss_1d,
332
316
  loss_stride=loss_1d.stride(-1), # always 1
@@ -338,7 +322,7 @@ def cross_entropy_forward(
338
322
  lse_square_scale=lse_square_scale,
339
323
  label_smoothing=label_smoothing,
340
324
  reduction=reduction,
341
- softcap=softcap if softcap is not None else 0.0,
325
+ softcap=softcap,
342
326
  RETURN_Z_LOSS=return_z_loss,
343
327
  BLOCK_SIZE=BLOCK_SIZE,
344
328
  HAS_WEIGHT=True if weight is not None else False,
@@ -350,10 +334,10 @@ def cross_entropy_forward(
350
334
 
351
335
  if reduction == "none":
352
336
  loss = loss_1d
353
- z_loss = z_loss_1d if return_z_loss == _TRUE.value else None
337
+ z_loss = z_loss_1d if return_z_loss else None
354
338
  else:
355
339
  loss = torch.sum(loss_1d)
356
- z_loss = torch.sum(z_loss_1d) if return_z_loss == _TRUE.value else None
340
+ z_loss = torch.sum(z_loss_1d) if return_z_loss else None
357
341
 
358
342
  return loss, z_loss, _input
359
343
 
@@ -92,9 +92,9 @@ def fused_linear_cross_entropy_forward(
92
92
  X_stride=logits_chunk.stride(-2),
93
93
  Y_ptr=target_chunk,
94
94
  Y_stride=target_chunk.stride(-1), # always 1
95
- weight_ptr=ce_weight if ce_weight is not None else _input, # dummy if None
95
+ weight_ptr=ce_weight,
96
96
  loss_ptr=loss_1d_slice,
97
- z_loss_ptr=loss_1d_slice, # dummy ptr, not used
97
+ z_loss_ptr=None,
98
98
  loss_stride=loss_1d_slice.stride(-1), # always 1
99
99
  n_cols=V,
100
100
  n_non_ignore=total_n_non_ignore,
@@ -104,8 +104,8 @@ def fused_linear_cross_entropy_forward(
104
104
  lse_square_scale=lse_square_scale,
105
105
  label_smoothing=label_smoothing,
106
106
  reduction=reduction,
107
- softcap=softcap if softcap is not None else 0.0,
108
- RETURN_Z_LOSS=0, # False
107
+ softcap=softcap,
108
+ RETURN_Z_LOSS=False,
109
109
  HAS_WEIGHT=True if ce_weight is not None else False,
110
110
  HAS_SOFTCAPPING=True if softcap is not None else False,
111
111
  BLOCK_SIZE=BLOCK_SIZE,
@@ -20,9 +20,6 @@ class LigerCrossEntropyLoss(torch.nn.Module):
20
20
  assert (label_smoothing >= 0) and (
21
21
  label_smoothing <= 1
22
22
  ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
23
- assert (label_smoothing >= 0) and (
24
- label_smoothing <= 1
25
- ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
26
23
  assert reduction in {
27
24
  "mean",
28
25
  "sum",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20250108072837
3
+ Version: 0.5.2.dev20250108102127
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -103,9 +103,10 @@ class TorchLMHeadCPO(torch.nn.Module):
103
103
  label_smoothing=label_smoothing,
104
104
  simpo_gamma=simpo_gamma,
105
105
  ).get_batch_loss_metrics
106
+ self.average_log_prob = loss_type == "simpo"
106
107
 
107
108
  def forward(self, x, y):
108
- return self.cpo_loss(self.lin.weight, x, y, self.lin.bias)
109
+ return self.cpo_loss(self.lin.weight, x, y, self.lin.bias, average_log_prob=self.average_log_prob)
109
110
 
110
111
 
111
112
  class LigerLMHeadCPO(torch.nn.Module):
@@ -143,7 +144,7 @@ class LigerLMHeadCPO(torch.nn.Module):
143
144
  @pytest.mark.parametrize(
144
145
  "scalar, dtype, atol, rtol",
145
146
  [
146
- (1.0, torch.bfloat16, 5e-3, 5e-3),
147
+ (1.0, torch.bfloat16, 5e-2, 5e-2),
147
148
  (1.0, torch.float32, 1e-5, 5e-4),
148
149
  ],
149
150
  )