liger-kernel-nightly 0.5.5.dev20250331042257__tar.gz → 0.5.5.dev20250402184001__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/dev/modal/tests.py +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/dev/modal/tests_bwd.py +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel_nightly-0.5.5.dev20250402184001/src/liger_kernel/chunked_loss/fused_linear_ppo.py +330 -0
- liger_kernel_nightly-0.5.5.dev20250402184001/src/liger_kernel/chunked_loss/grpo_loss.py +236 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/llava.py +20 -34
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -1
- liger_kernel_nightly-0.5.5.dev20250402184001/test/chunked_loss/test_grpo_loss.py +470 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/chunked_loss/test_orpo_loss.py +6 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/convergence/bf16/test_mini_models_multimodal.py +1 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/convergence/fp32/test_mini_models_multimodal.py +1 -0
- liger_kernel_nightly-0.5.5.dev20250331042257/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
- liger_kernel_nightly-0.5.5.dev20250331042257/src/liger_kernel/chunked_loss/grpo_loss.py +0 -194
- liger_kernel_nightly-0.5.5.dev20250331042257/test/chunked_loss/test_grpo_loss.py +0 -275
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/Makefile +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/setup.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/paligemma.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257 → liger_kernel_nightly-0.5.5.dev20250402184001}/test/utils.py +0 -0
@@ -14,7 +14,7 @@ app = modal.App("liger_tests", image=image)
|
|
14
14
|
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
|
15
15
|
|
16
16
|
|
17
|
-
@app.function(gpu="A10G", mounts=[repo], timeout=60 *
|
17
|
+
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 30)
|
18
18
|
def liger_tests():
|
19
19
|
import subprocess
|
20
20
|
|
@@ -14,7 +14,7 @@ app = modal.App("liger_tests_bwd", image=image)
|
|
14
14
|
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
|
15
15
|
|
16
16
|
|
17
|
-
@app.function(gpu="A10G", mounts=[repo], timeout=60 *
|
17
|
+
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 30)
|
18
18
|
def liger_bwd_tests():
|
19
19
|
import subprocess
|
20
20
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.5.5.
|
7
|
+
version = "0.5.5.dev20250402184001"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
|
2
2
|
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
|
3
|
+
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
|
3
4
|
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
|
4
5
|
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
|
5
6
|
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
|
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
|
|
11
12
|
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
|
12
13
|
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
|
13
14
|
liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
|
15
|
+
liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
|
liger_kernel_nightly-0.5.5.dev20250402184001/src/liger_kernel/chunked_loss/fused_linear_ppo.py
ADDED
@@ -0,0 +1,330 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from functools import partial
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch._dynamo.config
|
6
|
+
import torch.nn.functional as F
|
7
|
+
|
8
|
+
|
9
|
+
class LigerFusedLinearPPOBase(torch.autograd.Function):
    """Base autograd Function for chunked, fused linear-layer + PPO-style losses.

    Subclasses implement `ppo_loss_fn` (e.g. GRPO). The forward pass processes the
    batch in chunks, computing per-chunk gradients eagerly via
    `torch.func.grad_and_value` so the full [B, T, V] logits tensor is never
    materialized; `backward` then just rescales the pre-computed gradients.
    """

    @abstractmethod
    def ppo_loss_fn(*args, **kwargs):
        """
        To be extended by subclasses.
        """
        raise NotImplementedError("PPO loss function must be implemented.")

    @staticmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        selected_token_ids,
        attention_mask,
        advantages,
        bias=None,
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        epsilon_low=0.2,
        epsilon_high=0.2,
        beta=0.04,
        temperature=1.0,
        compiled=True,
        use_ref_model=False,
        chunk_size=1,
    ):
        """Chunked forward pass for PPO loss computation.

        Args:
            cls: The class
            ctx: Context for backward
            _input: Input tensor
            weight: Weight tensor
            selected_token_ids: Selected token ids tensor
            attention_mask: Attention mask tensor
            advantages: Advantages tensor
            bias: Bias tensor
            ref_per_token_logps: Reference model log probs per token tensor
            old_per_token_logps: Old per token log probabilities tensor
            ref_input: Reference model input tensor
            ref_weight: Reference model weight tensor
            ref_bias: Reference model bias tensor
            epsilon_low: Lower bound for clipping the importance sampling ratio
            epsilon_high: Upper bound for clipping the importance sampling ratio
            beta: Weight for the KL penalty
            temperature: Temperature for the logits
            compiled: Whether to use torch compile
            use_ref_model: Whether to use a reference model
            chunk_size: Size of chunks for processing in other loss modules
        """
        if use_ref_model:
            assert ref_per_token_logps is not None or ref_input is not None, (
                "If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
            )
            if ref_per_token_logps is not None and ref_input is not None:
                # Fix: the original code did `raise Warning(...)`, which aborts the
                # forward pass even though both inputs are usable and the message
                # itself says "Using ref_per_token_logps". Emit a real warning and
                # continue instead.
                import warnings

                warnings.warn("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
        # Initialize accumulators
        loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
        grad_weight = torch.zeros_like(weight)  # [V, H]
        grad_inputs = []
        grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
        aggregated_metrics = []

        # Create a partial function with fixed arguments
        compute_loss = partial(
            LigerFusedLinearPPOBase._compute_chunk_loss,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            full_attention_mask=attention_mask,
            epsilon_low=epsilon_low,
            epsilon_high=epsilon_high,
            beta=beta,
            temperature=temperature,
            use_ref_model=use_ref_model,
            ppo_loss_fn=cls.ppo_loss_fn,
        )

        def fused_fwd_bwd(
            input_chunk,
            selected_token_ids_chunk,
            attention_mask_chunk,
            advantages_chunk,
            ref_per_token_logps_chunk,
            old_per_token_logps_chunk,
            ref_input_chunk,
        ):
            """Fused forward and backward for a chunk."""
            # Differentiate w.r.t. input (0), weight (1), and bias (5) when present.
            argnums = (0, 1, 5) if bias is not None else (0, 1)
            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
                input_chunk,  # arg 0
                weight,  # arg 1
                selected_token_ids_chunk,  # arg 2
                attention_mask_chunk,  # arg 3
                advantages_chunk,  # arg 4
                bias,  # arg 5
                ref_per_token_logps_chunk=ref_per_token_logps_chunk,  # arg 6
                old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 7
                ref_input_chunk=ref_input_chunk,  # arg 8
            )

        def accumulate_chunk(
            input_chunk,
            selected_token_ids_chunk,
            attention_mask_chunk,
            advantages_chunk,
            ref_per_token_logps_chunk=None,
            old_per_token_logps_chunk=None,
            ref_input_chunk=None,
        ):
            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
                input_chunk,
                selected_token_ids_chunk,
                attention_mask_chunk,
                advantages_chunk,
                ref_per_token_logps_chunk,
                old_per_token_logps_chunk,
                ref_input_chunk,
            )
            if bias is not None:
                grad_bias.add_(chunk_grad_bias[0])

            # Accumulate gradients and loss
            grad_weight.add_(chunk_grad_weight)
            grad_inputs.append(chunk_grad_input)
            loss_acc.add_(chunk_loss)
            # Initialize storage for metrics on first chunk: scalar metrics are
            # summed in place; per-token metrics are concatenated at the end.
            if len(aggregated_metrics) == 0:
                for metric in chunk_metrics:
                    if metric.ndim == 0:
                        aggregated_metrics.append(torch.zeros((), device=metric.device))
                    else:
                        aggregated_metrics.append([])

            # Accumulate metrics
            for i, metric in enumerate(chunk_metrics):
                if metric.ndim == 0:
                    aggregated_metrics[i].add_(metric)
                else:
                    aggregated_metrics[i].append(metric)

        if compiled:
            # TODO: Figure out what is better to compile here
            # accumulate_chunk = torch.compile(accumulate_chunk)
            fused_fwd_bwd = torch.compile(fused_fwd_bwd)

        # Process input in chunks based on chunk_size
        chunks = max(1, _input.shape[0] // chunk_size)
        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
        _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
        _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
        _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
        _ref_per_token_logps_chunks = (
            torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
            if use_ref_model and ref_per_token_logps is not None
            else [None] * chunks
        )
        _old_per_token_logps_chunks = (
            torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
            if old_per_token_logps is not None
            else [None] * chunks
        )
        # if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
        _ref_input_chunks = (
            torch.chunk(ref_input, chunks=chunks, dim=0)
            if use_ref_model and ref_per_token_logps is None
            else [None] * chunks
        )

        for (
            input_chunk,
            selected_token_ids_chunk,
            attention_mask_chunk,
            advantages_chunk,
            ref_per_token_logps_chunk,
            old_per_token_logps_chunk,
            ref_input_chunk,
        ) in zip(
            _input_chunks,
            _selected_token_ids_chunks,
            _attention_mask_chunks,
            _advantages_chunks,
            _ref_per_token_logps_chunks,
            _old_per_token_logps_chunks,
            _ref_input_chunks,
        ):
            # Mark dynamic dimensions (sequence length) so torch.compile does not
            # re-specialize per chunk.
            torch._dynamo.mark_dynamic(input_chunk, 1)
            torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
            torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
            if ref_per_token_logps_chunk is not None:
                torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
            if ref_input_chunk is not None:
                torch._dynamo.mark_dynamic(ref_input_chunk, 1)
            if old_per_token_logps_chunk is not None:
                torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)

            accumulate_chunk(
                input_chunk,
                selected_token_ids_chunk,
                attention_mask_chunk,
                advantages_chunk,
                ref_per_token_logps_chunk,
                old_per_token_logps_chunk,
                ref_input_chunk,
            )

        # Combine gradients
        grad_input = torch.cat(grad_inputs, dim=0)

        # Save for backward
        ctx.save_for_backward(grad_input, grad_weight, grad_bias)

        # Finalize metrics
        final_metrics = []
        for metric in aggregated_metrics:
            if isinstance(metric, list):
                final_metrics.append(torch.cat(metric, dim=0))
            else:
                final_metrics.append(metric)

        return loss_acc, tuple(final_metrics)

    @staticmethod
    def _compute_chunk_loss(
        input_chunk,
        weight,
        selected_token_ids_chunk,
        attention_mask_chunk,
        advantages_chunk,
        bias=None,
        ref_per_token_logps_chunk=None,
        old_per_token_logps_chunk=None,
        ref_input_chunk=None,
        ref_weight=None,
        ref_bias=None,
        full_attention_mask=None,
        epsilon_low=0.2,
        epsilon_high=0.2,
        beta=0.04,
        temperature=1.0,
        use_ref_model=False,
        ppo_loss_fn=None,
    ):
        """Compute loss for a single chunk."""
        # Get policy log probabilities using chunk_forward
        log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)

        # Get reference log probabilities if needed
        ref_log_probs = None
        if use_ref_model and ref_per_token_logps_chunk is None:
            with torch.no_grad():
                ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
                    ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
                )

        # Compute chunk loss and metrics using the provided loss function
        chunk_loss, chunk_metrics = ppo_loss_fn(
            log_probs=log_probs,
            selected_token_ids=selected_token_ids_chunk,
            attention_mask=attention_mask_chunk,
            advantages=advantages_chunk,
            full_attention_mask=full_attention_mask,
            ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
            old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
            ref_log_probs=ref_log_probs,  # used when ref_per_token_logps is None
            epsilon_low=epsilon_low,
            epsilon_high=epsilon_high,
            beta=beta,
        )

        return chunk_loss, chunk_metrics

    @staticmethod
    def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
        """Forward pass computation for a single chunk without explicit reshaping."""
        # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
        logits = torch.matmul(input_chunk, weight.t())
        if bias is not None:
            logits = logits + bias  # Broadcasts bias to [B, T, V]
        if temperature != 1.0:
            logits = logits / temperature

        # Compute log probabilities using softmax over the last dimension
        log_probs = F.log_softmax(logits.float(), dim=-1)

        return log_probs, logits

    @staticmethod
    def backward(ctx, grad_output, *grad_metrics):
        """Backward pass for PPO loss.

        Gradients were already computed chunk-by-chunk during forward; here we only
        rescale them by the incoming grad_output (a scalar) when it is not 1.
        """
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if grad_output != 1.0:
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            if grad_bias is not None:
                grad_bias = grad_bias * grad_output

        return (
            grad_input,
            grad_weight,
            None,  # grad_selected_token_ids
            None,  # grad_attention_mask
            None,  # grad_advantages
            grad_bias,
            None,  # grad_ref_per_token_logps
            None,  # grad_old_per_token_logps
            None,  # grad_ref_input
            None,  # grad_ref_weight
            None,  # grad_ref_bias
            None,  # grad_epsilon_low
            None,  # grad_epsilon_high
            None,  # grad_beta
            None,  # grad_temperature
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_chunk_size
        )
|
@@ -0,0 +1,236 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
|
4
|
+
|
5
|
+
|
6
|
+
def k3_loss_fn(log_p, log_q):
    """Return the k3 estimator of KL[q, p].

    Reference: http://joschu.net/blog/kl-approx.html
    """
    log_ratio = log_p - log_q
    return log_ratio.exp() - log_ratio - 1.0
|
10
|
+
|
11
|
+
|
12
|
+
def clip_coef_fn(coef, epsilon_low, epsilon_high):
    """Clamp the importance-sampling ratio into [1 - epsilon_low, 1 + epsilon_high]."""
    lower, upper = 1 - epsilon_low, 1 + epsilon_high
    return coef.clamp(min=lower, max=upper)
|
14
|
+
|
15
|
+
|
16
|
+
class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
    """GRPO (Group Relative Policy Optimization) loss on top of the chunked fused-linear PPO base.

    Implements the clipped-ratio policy-gradient objective with an optional k3 KL
    penalty against a reference model, matching TRL's GRPOTrainer implementation.
    """

    @staticmethod
    def ppo_loss_fn(
        log_probs,
        selected_token_ids,
        attention_mask,
        advantages,
        full_attention_mask,
        ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
        old_per_token_logps=None,
        ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
        epsilon_low=0.2,
        epsilon_high=0.2,
        beta=0.04,
        **kwargs,
    ):
        """GRPO Loss Function matching GRPOTrainer implementation."""
        # Log-prob of each sampled token under the current policy.
        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
            -1
        )  # (batch_size, seq_len)

        # Get reference model probabilities
        if ref_per_token_logps is None:
            if ref_log_probs is not None:
                with torch.no_grad():
                    ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
                        -1
                    )
            else:
                # No reference available: fall back to the detached policy itself,
                # which makes the KL term evaluate to zero.
                ref_per_token_logps = per_token_logps.detach()

        # Compute policy gradient loss with importance sampling ratio
        # On-policy case: old = current (detached), so coef_1 == 1 exactly.
        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        # Pessimistic (clipped) PPO objective, negated to form a loss.
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if beta != 0.0:
            # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
            kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
            # Combine losses
            per_token_loss = per_token_loss + beta * kl_div

        # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
        # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
        # and TRL GRPO implementation
        # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
        loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)

        # Calculate metrics: [mean KL (only when beta != 0), clip fraction].
        metrics = []
        if beta != 0.0:
            metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
        # A token is "clipped" when the ratio left the trust region in the direction
        # that would otherwise increase the objective.
        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
        )
        metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
        return loss, metrics

    @classmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        selected_token_ids,
        attention_mask,
        advantages,
        bias=None,
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        beta=0.04,
        epsilon_low=0.2,
        epsilon_high=0.2,
        temperature=1.0,
        compiled=True,
        use_ref_model=True,
        chunk_size=1,
    ):
        """
        Fused linear layer with GRPO loss.
        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
            selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
            attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
            advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
            ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
            old_per_token_logps (torch.Tensor, optional): Behavior-policy log probs per token. Shape: (batch_size, seq_len)
            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
            beta (float): Weight for the KL penalty
            epsilon_low (float): Lower bound for clipping the importance sampling ratio
            epsilon_high (float): Upper bound for clipping the importance sampling ratio
            temperature (float): Temperature for the logits
            compiled (bool): Whether to use torch compile
            use_ref_model (bool): Whether to use a reference model
            chunk_size (int): Size of chunks for processing.
        Returns:
            torch.Tensor: Computed loss
        """
        return super().forward(
            cls=cls,
            ctx=ctx,
            _input=_input,
            weight=weight,
            selected_token_ids=selected_token_ids,
            attention_mask=attention_mask,
            advantages=advantages,
            bias=bias,
            ref_per_token_logps=ref_per_token_logps,
            old_per_token_logps=old_per_token_logps,
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            beta=beta,
            epsilon_low=epsilon_low,
            epsilon_high=epsilon_high,
            temperature=temperature,
            compiled=compiled,
            use_ref_model=use_ref_model,
            chunk_size=chunk_size,
        )

    @staticmethod
    def backward(ctx, grad_output, *grad_metrics):
        """Backward pass for GRPO loss.

        Args:
            grad_output: Gradient of the loss (scalar)
            grad_metrics: Gradients of the metrics (not used in backward computation)
        """
        grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
        # The base returns 18 gradient slots for its own signature; the first six
        # line up with this forward's leading arguments, the rest are rebuilt as
        # None to match this forward's (shorter, reordered) argument list.
        return (
            *grads[:6],  # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
            None,  # grad_ref_per_token_logps
            None,  # grad_old_per_token_logps
            None,  # grad_ref_input
            None,  # grad_ref_weight
            None,  # grad_ref_bias
            None,  # grad_beta
            None,  # grad_epsilon_low
            None,  # grad_epsilon_high
            None,  # grad_temperature
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_chunk_size
        )
|
169
|
+
|
170
|
+
|
171
|
+
class LigerFusedLinearGRPOLoss(torch.nn.Module):
    """nn.Module wrapper around :class:`LigerFusedLinearGRPOFunction`.

    Holds the GRPO hyperparameters and forwards tensors plus configuration to the
    fused autograd function.
    """

    def __init__(
        self,
        beta: float = 0.04,
        compiled: bool = True,
        use_ref_model: bool = True,
        chunk_size: int = 1,
        epsilon_low: float = 0.2,
        epsilon_high: float = 0.2,
        temperature: float = 1.0,
    ):
        """
        Args:
            beta (float): Weight for the KL penalty.
            compiled (bool): Whether to use torch compile.
            use_ref_model (bool): Whether to use a reference model.
            chunk_size (int): Size of chunks for processing.
            epsilon_low (float): Lower bound for the importance sampling ratio.
            epsilon_high (float): Upper bound for the importance sampling ratio.
            temperature (float): Temperature for the logits.
        """
        super().__init__()
        # Loss weighting / clipping configuration.
        self.beta = beta
        self.epsilon_low = epsilon_low
        self.epsilon_high = epsilon_high
        self.temperature = temperature
        # Execution configuration.
        self.compiled = compiled
        self.use_ref_model = use_ref_model
        self.chunk_size = chunk_size

    def forward(
        self,
        _input,
        lin_weight,
        selected_token_ids,
        attention_mask,
        advantages,
        bias=None,
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        """Compute the fused-linear GRPO loss; returns (loss, metrics)."""
        # Order must match LigerFusedLinearGRPOFunction.forward's signature.
        call_args = (
            _input,
            lin_weight,
            selected_token_ids,
            attention_mask,
            advantages,
            bias,
            ref_per_token_logps,
            old_per_token_logps,
            ref_input,
            ref_weight,
            ref_bias,
            self.beta,
            self.epsilon_low,
            self.epsilon_high,
            self.temperature,
            self.compiled,
            self.use_ref_model,
            self.chunk_size,
        )
        return LigerFusedLinearGRPOFunction.apply(*call_args)
|