liger-kernel-nightly 0.5.5.dev20250331170510__tar.gz → 0.5.5.dev20250402185606__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/dev/modal/tests.py +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/dev/modal/tests_bwd.py +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel_nightly-0.5.5.dev20250402185606/src/liger_kernel/chunked_loss/fused_linear_ppo.py +330 -0
- liger_kernel_nightly-0.5.5.dev20250402185606/src/liger_kernel/chunked_loss/grpo_loss.py +236 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/cross_entropy.py +3 -2
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -1
- liger_kernel_nightly-0.5.5.dev20250402185606/test/chunked_loss/test_grpo_loss.py +470 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_orpo_loss.py +6 -0
- liger_kernel_nightly-0.5.5.dev20250331170510/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
- liger_kernel_nightly-0.5.5.dev20250331170510/src/liger_kernel/chunked_loss/grpo_loss.py +0 -194
- liger_kernel_nightly-0.5.5.dev20250331170510/test/chunked_loss/test_grpo_loss.py +0 -275
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/Makefile +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/setup.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/llava.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/paligemma.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/utils.py +0 -0
@@ -14,7 +14,7 @@ app = modal.App("liger_tests", image=image)
|
|
14
14
|
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
|
15
15
|
|
16
16
|
|
17
|
-
@app.function(gpu="A10G", mounts=[repo], timeout=60 *
|
17
|
+
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 30)
|
18
18
|
def liger_tests():
|
19
19
|
import subprocess
|
20
20
|
|
@@ -14,7 +14,7 @@ app = modal.App("liger_tests_bwd", image=image)
|
|
14
14
|
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
|
15
15
|
|
16
16
|
|
17
|
-
@app.function(gpu="A10G", mounts=[repo], timeout=60 *
|
17
|
+
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 30)
|
18
18
|
def liger_bwd_tests():
|
19
19
|
import subprocess
|
20
20
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.5.5.
|
7
|
+
version = "0.5.5.dev20250402185606"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
|
2
2
|
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
|
3
|
+
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
|
3
4
|
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
|
4
5
|
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
|
5
6
|
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
|
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
|
|
11
12
|
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
|
12
13
|
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
|
13
14
|
liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
|
15
|
+
liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
|
liger_kernel_nightly-0.5.5.dev20250402185606/src/liger_kernel/chunked_loss/fused_linear_ppo.py
ADDED
@@ -0,0 +1,330 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from functools import partial
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch._dynamo.config
|
6
|
+
import torch.nn.functional as F
|
7
|
+
|
8
|
+
|
9
|
+
class LigerFusedLinearPPOBase(torch.autograd.Function):
|
10
|
+
@abstractmethod
def ppo_loss_fn(*args, **kwargs):
    """Compute the PPO-family loss for one chunk.

    Concrete subclasses (e.g. the GRPO function) must override this with
    their actual loss computation; the base class never implements it.
    """
    raise NotImplementedError("PPO loss function must be implemented.")
17
|
+
@staticmethod
def forward(
    cls,
    ctx,
    _input,
    weight,
    selected_token_ids,
    attention_mask,
    advantages,
    bias=None,
    ref_per_token_logps=None,
    old_per_token_logps=None,
    ref_input=None,
    ref_weight=None,
    ref_bias=None,
    epsilon_low=0.2,
    epsilon_high=0.2,
    beta=0.04,
    temperature=1.0,
    compiled=True,
    use_ref_model=False,
    chunk_size=1,
):
    """Chunked forward pass for PPO loss computation.

    Processes the batch in chunks, running a fused forward+backward per chunk
    via ``torch.func.grad_and_value`` so full-batch logits are never
    materialized. Gradients are accumulated during forward and saved on
    ``ctx``; ``backward`` then just rescales and returns them.

    Args:
        cls: The class (supplies the concrete ``ppo_loss_fn``)
        ctx: Context for backward
        _input: Input tensor
        weight: Weight tensor
        selected_token_ids: Selected token ids tensor
        attention_mask: Attention mask tensor
        advantages: Advantages tensor
        bias: Bias tensor
        ref_per_token_logps: Reference model log probs per token tensor
        old_per_token_logps: Old per token log probabilities tensor
        ref_input: Reference model input tensor
        ref_weight: Reference model weight tensor
        ref_bias: Reference model bias tensor
        epsilon_low: Lower bound for clipping the importance sampling ratio
        epsilon_high: Upper bound for clipping the importance sampling ratio
        beta: Weight for the KL penalty
        temperature: Temperature for the logits
        compiled: Whether to use torch compile
        use_ref_model: Whether to use a reference model
        chunk_size: Size of chunks for processing in other loss modules

    Returns:
        Tuple of (accumulated scalar loss, tuple of aggregated metrics).
    """
    if use_ref_model:
        assert ref_per_token_logps is not None or ref_input is not None, (
            "If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
        )
        if ref_per_token_logps is not None and ref_input is not None:
            # FIX: this previously did `raise Warning(...)`, which aborted the
            # call even though the message promised to fall back to
            # ref_per_token_logps. Emit a real warning and continue, matching
            # the documented behavior.
            import warnings

            warnings.warn("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
    # Initialize accumulators
    loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
    grad_weight = torch.zeros_like(weight)  # [V, H]
    grad_inputs = []
    grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
    aggregated_metrics = []

    # Create a partial function with fixed arguments
    compute_loss = partial(
        LigerFusedLinearPPOBase._compute_chunk_loss,
        ref_weight=ref_weight,
        ref_bias=ref_bias,
        full_attention_mask=attention_mask,
        epsilon_low=epsilon_low,
        epsilon_high=epsilon_high,
        beta=beta,
        temperature=temperature,
        use_ref_model=use_ref_model,
        ppo_loss_fn=cls.ppo_loss_fn,
    )

    def fused_fwd_bwd(
        input_chunk,
        selected_token_ids_chunk,
        attention_mask_chunk,
        advantages_chunk,
        ref_per_token_logps_chunk,
        old_per_token_logps_chunk,
        ref_input_chunk,
    ):
        """Fused forward and backward for a chunk."""
        # Differentiate w.r.t. input, weight, and (when present) bias only.
        argnums = (0, 1, 5) if bias is not None else (0, 1)
        return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
            input_chunk,  # arg 0
            weight,  # arg 1
            selected_token_ids_chunk,  # arg 2
            attention_mask_chunk,  # arg 3
            advantages_chunk,  # arg 4
            bias,  # arg 5
            ref_per_token_logps_chunk=ref_per_token_logps_chunk,  # arg 6
            old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 7
            ref_input_chunk=ref_input_chunk,  # arg 8
        )

    def accumulate_chunk(
        input_chunk,
        selected_token_ids_chunk,
        attention_mask_chunk,
        advantages_chunk,
        ref_per_token_logps_chunk=None,
        old_per_token_logps_chunk=None,
        ref_input_chunk=None,
    ):
        (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
            input_chunk,
            selected_token_ids_chunk,
            attention_mask_chunk,
            advantages_chunk,
            ref_per_token_logps_chunk,
            old_per_token_logps_chunk,
            ref_input_chunk,
        )
        if bias is not None:
            grad_bias.add_(chunk_grad_bias[0])

        # Accumulate gradients and loss
        grad_weight.add_(chunk_grad_weight)
        grad_inputs.append(chunk_grad_input)
        loss_acc.add_(chunk_loss)
        # Initialize storage for metrics on first chunk: scalar metrics are
        # summed in place, per-row metrics are concatenated at the end.
        if len(aggregated_metrics) == 0:
            for metric in chunk_metrics:
                if metric.ndim == 0:
                    aggregated_metrics.append(torch.zeros((), device=metric.device))
                else:
                    aggregated_metrics.append([])

        # Accumulate metrics
        for i, metric in enumerate(chunk_metrics):
            if metric.ndim == 0:
                aggregated_metrics[i].add_(metric)
            else:
                aggregated_metrics[i].append(metric)

    if compiled:
        # TODO: Figure out what is better to compile here
        # accumulate_chunk = torch.compile(accumulate_chunk)
        fused_fwd_bwd = torch.compile(fused_fwd_bwd)

    # Process input in chunks based on chunk_size
    chunks = max(1, _input.shape[0] // chunk_size)
    _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
    _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
    _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
    _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
    _ref_per_token_logps_chunks = (
        torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
        if use_ref_model and ref_per_token_logps is not None
        else [None] * chunks
    )
    _old_per_token_logps_chunks = (
        torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
        if old_per_token_logps is not None
        else [None] * chunks
    )
    # if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
    _ref_input_chunks = (
        torch.chunk(ref_input, chunks=chunks, dim=0)
        if use_ref_model and ref_per_token_logps is None
        else [None] * chunks
    )

    for (
        input_chunk,
        selected_token_ids_chunk,
        attention_mask_chunk,
        advantages_chunk,
        ref_per_token_logps_chunk,
        old_per_token_logps_chunk,
        ref_input_chunk,
    ) in zip(
        _input_chunks,
        _selected_token_ids_chunks,
        _attention_mask_chunks,
        _advantages_chunks,
        _ref_per_token_logps_chunks,
        _old_per_token_logps_chunks,
        _ref_input_chunks,
    ):
        # Mark dynamic dimensions so torch.compile does not respecialize per
        # sequence length.
        torch._dynamo.mark_dynamic(input_chunk, 1)
        torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
        torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
        if ref_per_token_logps_chunk is not None:
            torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
        if ref_input_chunk is not None:
            torch._dynamo.mark_dynamic(ref_input_chunk, 1)
        if old_per_token_logps_chunk is not None:
            torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)

        accumulate_chunk(
            input_chunk,
            selected_token_ids_chunk,
            attention_mask_chunk,
            advantages_chunk,
            ref_per_token_logps_chunk,
            old_per_token_logps_chunk,
            ref_input_chunk,
        )

    # Combine gradients
    grad_input = torch.cat(grad_inputs, dim=0)

    # Save for backward
    ctx.save_for_backward(grad_input, grad_weight, grad_bias)

    # Finalize metrics
    final_metrics = []
    for metric in aggregated_metrics:
        if isinstance(metric, list):
            final_metrics.append(torch.cat(metric, dim=0))
        else:
            final_metrics.append(metric)

    return loss_acc, tuple(final_metrics)
236
|
+
@staticmethod
def _compute_chunk_loss(
    input_chunk,
    weight,
    selected_token_ids_chunk,
    attention_mask_chunk,
    advantages_chunk,
    bias=None,
    ref_per_token_logps_chunk=None,
    old_per_token_logps_chunk=None,
    ref_input_chunk=None,
    ref_weight=None,
    ref_bias=None,
    full_attention_mask=None,
    epsilon_low=0.2,
    epsilon_high=0.2,
    beta=0.04,
    temperature=1.0,
    use_ref_model=False,
    ppo_loss_fn=None,
):
    """Compute loss for a single chunk.

    Runs the policy head on the chunk, optionally runs the reference head
    under no_grad, then delegates to the subclass-supplied ``ppo_loss_fn``.
    This function is differentiated as a whole by ``torch.func.grad_and_value``
    in ``forward``, so only ``input_chunk``, ``weight`` and ``bias`` may
    require grad here.

    Returns:
        Tuple of (scalar chunk loss, metrics list from ``ppo_loss_fn``).
    """
    # Get policy log probabilities using chunk_forward
    log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)

    # Get reference log probabilities if needed
    ref_log_probs = None
    if use_ref_model and ref_per_token_logps_chunk is None:
        # Reference model is frozen: never build a graph through it.
        with torch.no_grad():
            ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
                ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
            )

    # Compute chunk loss and metrics using the provided loss function
    # (precomputed per-token logps are upcast to float32 to match log_probs).
    chunk_loss, chunk_metrics = ppo_loss_fn(
        log_probs=log_probs,
        selected_token_ids=selected_token_ids_chunk,
        attention_mask=attention_mask_chunk,
        advantages=advantages_chunk,
        full_attention_mask=full_attention_mask,
        ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
        old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
        ref_log_probs=ref_log_probs,  # used when ref_per_token_logps is None
        epsilon_low=epsilon_low,
        epsilon_high=epsilon_high,
        beta=beta,
    )

    return chunk_loss, chunk_metrics
286
|
+
@staticmethod
|
287
|
+
def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
|
288
|
+
"""Forward pass computation for a single chunk without explicit reshaping."""
|
289
|
+
# Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
|
290
|
+
logits = torch.matmul(input_chunk, weight.t())
|
291
|
+
if bias is not None:
|
292
|
+
logits = logits + bias # Broadcasts bias to [B, T, V]
|
293
|
+
if temperature != 1.0:
|
294
|
+
logits = logits / temperature
|
295
|
+
|
296
|
+
# Compute log probabilities using softmax over the last dimension
|
297
|
+
log_probs = F.log_softmax(logits.float(), dim=-1)
|
298
|
+
|
299
|
+
return log_probs, logits
|
300
|
+
|
301
|
+
@staticmethod
|
302
|
+
def backward(ctx, grad_output, *grad_metrics):
|
303
|
+
"""Backward pass for PPO loss."""
|
304
|
+
grad_input, grad_weight, grad_bias = ctx.saved_tensors
|
305
|
+
if grad_output != 1.0:
|
306
|
+
grad_input = grad_input * grad_output
|
307
|
+
grad_weight = grad_weight * grad_output
|
308
|
+
if grad_bias is not None:
|
309
|
+
grad_bias = grad_bias * grad_output
|
310
|
+
|
311
|
+
return (
|
312
|
+
grad_input,
|
313
|
+
grad_weight,
|
314
|
+
None, # grad_selected_token_ids
|
315
|
+
None, # grad_attention_mask
|
316
|
+
None, # grad_advantages
|
317
|
+
grad_bias,
|
318
|
+
None, # grad_ref_per_token_logps
|
319
|
+
None, # grad_old_per_token_logps
|
320
|
+
None, # grad_ref_input
|
321
|
+
None, # grad_ref_weight
|
322
|
+
None, # grad_ref_bias
|
323
|
+
None, # grad_epsilon_low
|
324
|
+
None, # grad_epsilon_high
|
325
|
+
None, # grad_beta
|
326
|
+
None, # grad_temperature
|
327
|
+
None, # grad_compiled
|
328
|
+
None, # grad_use_ref_model
|
329
|
+
None, # grad_chunk_size
|
330
|
+
)
|
@@ -0,0 +1,236 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
|
4
|
+
|
5
|
+
|
6
|
+
def k3_loss_fn(log_p, log_q):
    """Per-token k3 estimator of KL[q, p] (ref: http://joschu.net/blog/kl-approx.html).

    With d = log_p - log_q this is exp(d) - d - 1, which is non-negative and
    an unbiased, low-variance estimate of the KL divergence.
    """
    diff = log_p - log_q
    return diff.exp() - diff - 1.0
11
|
+
|
12
|
+
def clip_coef_fn(coef, epsilon_low, epsilon_high):
    """Clip the importance-sampling ratio into [1 - epsilon_low, 1 + epsilon_high]."""
    return coef.clamp(min=1 - epsilon_low, max=1 + epsilon_high)
15
|
+
|
16
|
+
class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
|
17
|
+
@staticmethod
def ppo_loss_fn(
    log_probs,
    selected_token_ids,
    attention_mask,
    advantages,
    full_attention_mask,
    ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
    old_per_token_logps=None,
    ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
    epsilon_low=0.2,
    epsilon_high=0.2,
    beta=0.04,
    **kwargs,
):
    """GRPO Loss Function matching GRPOTrainer implementation.

    Computes the clipped-surrogate policy-gradient loss per token, optionally
    adds a k3 KL penalty against the reference policy (beta != 0), and
    normalizes by the token count of the FULL batch (full_attention_mask) so
    per-chunk losses sum to the batch loss.

    Returns:
        Tuple of (scalar loss, metrics list). Metrics are
        [mean KL (only when beta != 0), clip fraction].
    """
    per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
        -1
    )  # (batch_size, seq_len)

    # Get reference model probabilities
    if ref_per_token_logps is None:
        if ref_log_probs is not None:
            with torch.no_grad():
                ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
                    -1
                )
        else:
            # No reference available: fall back to the detached policy itself,
            # which makes the KL penalty exactly zero.
            ref_per_token_logps = per_token_logps.detach()

    # Compute policy gradient loss with importance sampling ratio
    # (single-iteration case: old policy == detached current policy, ratio == 1)
    old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
    coef_1 = torch.exp(per_token_logps - old_per_token_logps)
    coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
    per_token_loss1 = coef_1 * advantages.unsqueeze(1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(1)
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
    if beta != 0.0:
        # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
        kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
        # Combine losses
        per_token_loss = per_token_loss + beta * kl_div

    # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
    # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
    # and TRL GRPO implementation
    # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
    loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)

    # Calculate metrics
    metrics = []
    if beta != 0.0:
        metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
    # A token counts as clipped only when clipping actually changed the loss:
    # ratio below the band with negative advantage, or above with positive.
    is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
        (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
    )
    metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
    return loss, metrics
76
|
+
@classmethod
def forward(
    cls,
    ctx,
    _input,
    weight,
    selected_token_ids,
    attention_mask,
    advantages,
    bias=None,
    ref_per_token_logps=None,
    old_per_token_logps=None,
    ref_input=None,
    ref_weight=None,
    ref_bias=None,
    beta=0.04,
    epsilon_low=0.2,
    epsilon_high=0.2,
    temperature=1.0,
    compiled=True,
    use_ref_model=True,
    chunk_size=1,
):
    """
    Fused linear layer with GRPO loss.
    Args:
        _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
        weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
        selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
        attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
        advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
        bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
        ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
        old_per_token_logps (torch.Tensor, optional): Old policy per-token log probs. Shape: (batch_size, seq_len)
        ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
        ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
        ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
        beta (float): Weight for the KL penalty
        epsilon_low (float): Lower clipping bound for the importance sampling ratio
        epsilon_high (float): Upper clipping bound for the importance sampling ratio
        temperature (float): Temperature for the logits
        compiled (bool): Whether to use torch compile
        use_ref_model (bool): Whether to use a reference model
        chunk_size (int): Size of chunks for processing.
    Returns:
        torch.Tensor: Computed loss
    """
    # All chunking/accumulation machinery lives in the PPO base; this class
    # only contributes ppo_loss_fn.
    return super().forward(
        cls=cls,
        ctx=ctx,
        _input=_input,
        weight=weight,
        selected_token_ids=selected_token_ids,
        attention_mask=attention_mask,
        advantages=advantages,
        bias=bias,
        ref_per_token_logps=ref_per_token_logps,
        old_per_token_logps=old_per_token_logps,
        ref_input=ref_input,
        ref_weight=ref_weight,
        ref_bias=ref_bias,
        beta=beta,
        epsilon_low=epsilon_low,
        epsilon_high=epsilon_high,
        temperature=temperature,
        compiled=compiled,
        use_ref_model=use_ref_model,
        chunk_size=chunk_size,
    )
143
|
+
@staticmethod
def backward(ctx, grad_output, *grad_metrics):
    """Backward pass for GRPO loss.

    Delegates to the PPO base backward for the differentiable inputs and
    ignores gradients on the metric outputs.

    Args:
        grad_output: Gradient of the loss (scalar)
        grad_metrics: Gradients of the metrics (not used in backward computation)
    """
    grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
    # First six slots: grad_input, grad_weight, grad_selected_token_ids,
    # grad_attention_mask, grad_advantages, grad_bias. The remaining twelve
    # forward() arguments (reference-model tensors, hyperparameters and
    # flags) are non-differentiable.
    return tuple(grads[:6]) + (None,) * 12
170
|
+
|
171
|
+
class LigerFusedLinearGRPOLoss(torch.nn.Module):
    """Fused linear layer with GRPO loss.

    nn.Module wrapper around LigerFusedLinearGRPOFunction: stores the loss
    hyperparameters and forwards the tensors plus those settings to the
    autograd function on every call.
    """

    def __init__(
        self,
        beta: float = 0.04,
        compiled: bool = True,
        use_ref_model: bool = True,
        chunk_size: int = 1,
        epsilon_low: float = 0.2,
        epsilon_high: float = 0.2,
        temperature: float = 1.0,
    ):
        """
        Args:
            beta (float): Weight for the KL penalty.
            compiled (bool): Whether to use torch compile.
            use_ref_model (bool): Whether to use a reference model.
            chunk_size (int): Size of chunks for processing.
            epsilon_low (float): Lower bound for the importance sampling ratio.
            epsilon_high (float): Upper bound for the importance sampling ratio.
            temperature (float): Temperature for the logits.
        """
        super().__init__()
        self.beta = beta
        self.compiled = compiled
        self.use_ref_model = use_ref_model
        self.chunk_size = chunk_size
        self.epsilon_low = epsilon_low
        self.epsilon_high = epsilon_high
        self.temperature = temperature

    def forward(
        self,
        _input,
        lin_weight,
        selected_token_ids,
        attention_mask,
        advantages,
        bias=None,
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        """Compute the GRPO loss for a batch.

        Returns whatever LigerFusedLinearGRPOFunction.forward returns
        (loss plus aggregated metrics).
        """
        # NOTE: Function.apply takes positional args only — the order below
        # must match LigerFusedLinearGRPOFunction.forward exactly.
        return LigerFusedLinearGRPOFunction.apply(
            _input,
            lin_weight,
            selected_token_ids,
            attention_mask,
            advantages,
            bias,
            ref_per_token_logps,
            old_per_token_logps,
            ref_input,
            ref_weight,
            ref_bias,
            self.beta,
            self.epsilon_low,
            self.epsilon_high,
            self.temperature,
            self.compiled,
            self.use_ref_model,
            self.chunk_size,
        )
|
@@ -9,6 +9,7 @@ import triton.language as tl
|
|
9
9
|
from liger_kernel.ops.utils import compare_version
|
10
10
|
from liger_kernel.ops.utils import element_mul_kernel
|
11
11
|
from liger_kernel.ops.utils import is_hip
|
12
|
+
from liger_kernel.utils import infer_device
|
12
13
|
|
13
14
|
if compare_version("triton", operator.ge, "3.0.0"):
|
14
15
|
try:
|
@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
|
|
59
60
|
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
|
60
61
|
loss_stride (int): The stride of the loss tensor.
|
61
62
|
n_cols (int): The number of columns in the input tensor.
|
62
|
-
n_non_ignore (
|
63
|
+
n_non_ignore (float): The number of non-ignored elements in the batch.
|
63
64
|
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
|
64
65
|
weight_sum (float): The sum of weight tensor.
|
65
66
|
ignore_index (int): The index to ignore in the target.
|
@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
|
|
258
259
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
259
260
|
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
260
261
|
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
261
|
-
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
|
262
|
+
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
|
262
263
|
|
263
264
|
|
264
265
|
def cross_entropy_forward(
|
@@ -112,8 +112,8 @@ src/liger_kernel/chunked_loss/cpo_loss.py
|
|
112
112
|
src/liger_kernel/chunked_loss/dpo_loss.py
|
113
113
|
src/liger_kernel/chunked_loss/functional.py
|
114
114
|
src/liger_kernel/chunked_loss/fused_linear_distillation.py
|
115
|
+
src/liger_kernel/chunked_loss/fused_linear_ppo.py
|
115
116
|
src/liger_kernel/chunked_loss/fused_linear_preference.py
|
116
|
-
src/liger_kernel/chunked_loss/fused_linear_rlhf.py
|
117
117
|
src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py
|
118
118
|
src/liger_kernel/chunked_loss/grpo_loss.py
|
119
119
|
src/liger_kernel/chunked_loss/jsd_loss.py
|