liger-kernel-nightly 0.5.6.dev20250411201510__tar.gz → 0.5.6.dev20250411224032__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +15 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/grpo_loss.py +33 -1
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/gemma2.py +1 -1
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/gemma3.py +1 -1
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/loss_utils.py +17 -10
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/chunked_loss/test_grpo_loss.py +35 -3
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/Makefile +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/README.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/setup.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/dyt.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/dyt.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/gema3_rms.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/llava.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/paligemma.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_dyt.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/test/utils.py +0 -0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.5.6.
|
7
|
+
version = "0.5.6.dev20250411224032"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -32,6 +32,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
32
32
|
epsilon_low=0.2,
|
33
33
|
epsilon_high=0.2,
|
34
34
|
beta=0.04,
|
35
|
+
loss_type="bnpo",
|
36
|
+
max_completion_length=None,
|
35
37
|
temperature=1.0,
|
36
38
|
compiled=True,
|
37
39
|
use_ref_model=False,
|
@@ -57,6 +59,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
57
59
|
epsilon_low: Lower bound for clipping the importance sampling ratio
|
58
60
|
epsilon_high: Upper bound for clipping the importance sampling ratio
|
59
61
|
beta: Weight for the KL penalty
|
62
|
+
loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
|
63
|
+
max_completion_length: Maximum completion length required for "dr_grpo"
|
60
64
|
temperature: Temperature for the logits
|
61
65
|
compiled: Whether to use torch compile
|
62
66
|
use_ref_model: Whether to use a reference model
|
@@ -68,6 +72,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
68
72
|
)
|
69
73
|
if ref_per_token_logps is not None and ref_input is not None:
|
70
74
|
raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
|
75
|
+
if loss_type == "dr_grpo":
|
76
|
+
assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
|
71
77
|
# Initialize accumulators
|
72
78
|
loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
|
73
79
|
grad_weight = torch.zeros_like(weight) # [V, H]
|
@@ -84,6 +90,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
84
90
|
epsilon_low=epsilon_low,
|
85
91
|
epsilon_high=epsilon_high,
|
86
92
|
beta=beta,
|
93
|
+
loss_type=loss_type,
|
94
|
+
max_completion_length=max_completion_length,
|
87
95
|
temperature=temperature,
|
88
96
|
use_ref_model=use_ref_model,
|
89
97
|
ppo_loss_fn=cls.ppo_loss_fn,
|
@@ -251,6 +259,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
251
259
|
epsilon_low=0.2,
|
252
260
|
epsilon_high=0.2,
|
253
261
|
beta=0.04,
|
262
|
+
loss_type="bnpo",
|
263
|
+
max_completion_length=None,
|
254
264
|
temperature=1.0,
|
255
265
|
use_ref_model=False,
|
256
266
|
ppo_loss_fn=None,
|
@@ -280,6 +290,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
280
290
|
epsilon_low=epsilon_low,
|
281
291
|
epsilon_high=epsilon_high,
|
282
292
|
beta=beta,
|
293
|
+
loss_type=loss_type,
|
294
|
+
max_completion_length=max_completion_length,
|
283
295
|
)
|
284
296
|
|
285
297
|
return chunk_loss, chunk_metrics
|
@@ -303,6 +315,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
303
315
|
def backward(ctx, grad_output, *grad_metrics):
|
304
316
|
"""Backward pass for PPO loss."""
|
305
317
|
grad_input, grad_weight, grad_bias = ctx.saved_tensors
|
318
|
+
|
306
319
|
if grad_output != 1.0:
|
307
320
|
grad_input = grad_input * grad_output
|
308
321
|
grad_weight = grad_weight * grad_output
|
@@ -328,4 +341,6 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
328
341
|
None, # grad_compiled
|
329
342
|
None, # grad_use_ref_model
|
330
343
|
None, # grad_chunk_size
|
344
|
+
None, # grad_loss_type
|
345
|
+
None, # grad_max_completion_length
|
331
346
|
)
|
@@ -27,6 +27,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
|
|
27
27
|
epsilon_low=0.2,
|
28
28
|
epsilon_high=0.2,
|
29
29
|
beta=0.04,
|
30
|
+
loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
|
31
|
+
max_completion_length=None, # Required for dr_grpo
|
30
32
|
**kwargs,
|
31
33
|
):
|
32
34
|
"""GRPO Loss Function matching GRPOTrainer implementation."""
|
@@ -61,7 +63,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
|
|
61
63
|
# which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
|
62
64
|
# and TRL GRPO implementation
|
63
65
|
# (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
|
64
|
-
|
66
|
+
if loss_type == "grpo":
|
67
|
+
# Average per-sequence loss
|
68
|
+
loss = (
|
69
|
+
(per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
|
70
|
+
).sum() / full_attention_mask.shape[0]
|
71
|
+
elif loss_type == "bnpo":
|
72
|
+
# Batch Normalized Per-token loss (original implementation)
|
73
|
+
loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
|
74
|
+
elif loss_type == "dr_grpo":
|
75
|
+
# Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
|
76
|
+
if max_completion_length is None:
|
77
|
+
raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
|
78
|
+
loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
|
79
|
+
else:
|
80
|
+
raise ValueError(f"Unknown loss type: {loss_type}")
|
65
81
|
|
66
82
|
# Calculate metrics
|
67
83
|
metrics = []
|
@@ -91,6 +107,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
|
|
91
107
|
beta=0.04,
|
92
108
|
epsilon_low=0.2,
|
93
109
|
epsilon_high=0.2,
|
110
|
+
loss_type="bnpo",
|
111
|
+
max_completion_length=None,
|
94
112
|
temperature=1.0,
|
95
113
|
compiled=True,
|
96
114
|
use_ref_model=True,
|
@@ -110,6 +128,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
|
|
110
128
|
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
|
111
129
|
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
|
112
130
|
beta (float): Weight for the KL penalty
|
131
|
+
loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
|
132
|
+
max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
|
113
133
|
temperature (float): Temperature for the logits
|
114
134
|
compiled (bool): Whether to use torch compile
|
115
135
|
use_ref_model (bool): Whether to use a reference model
|
@@ -134,6 +154,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
|
|
134
154
|
beta=beta,
|
135
155
|
epsilon_low=epsilon_low,
|
136
156
|
epsilon_high=epsilon_high,
|
157
|
+
loss_type=loss_type,
|
158
|
+
max_completion_length=max_completion_length,
|
137
159
|
temperature=temperature,
|
138
160
|
compiled=compiled,
|
139
161
|
use_ref_model=use_ref_model,
|
@@ -161,6 +183,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
|
|
161
183
|
None, # grad_beta
|
162
184
|
None, # grad_epsilon_low
|
163
185
|
None, # grad_epsilon_high
|
186
|
+
None, # grad_loss_type (string, not differentiable)
|
187
|
+
None, # grad_max_completion_length (int, not differentiable)
|
164
188
|
None, # grad_temperature
|
165
189
|
None, # grad_compiled
|
166
190
|
None, # grad_use_ref_model
|
@@ -179,6 +203,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
|
|
179
203
|
chunk_size: int = 1,
|
180
204
|
epsilon_low: float = 0.2,
|
181
205
|
epsilon_high: float = 0.2,
|
206
|
+
loss_type: str = "bnpo",
|
207
|
+
max_completion_length: int | None = None,
|
182
208
|
temperature: float = 1.0,
|
183
209
|
):
|
184
210
|
"""
|
@@ -189,6 +215,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
|
|
189
215
|
chunk_size (int): Size of chunks for processing.
|
190
216
|
epsilon_low (float): Lower bound for the importance sampling ratio.
|
191
217
|
epsilon_high (float): Upper bound for the importance sampling ratio.
|
218
|
+
loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
|
219
|
+
max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
|
192
220
|
temperature (float): Temperature for the logits.
|
193
221
|
"""
|
194
222
|
super().__init__()
|
@@ -198,6 +226,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
|
|
198
226
|
self.chunk_size = chunk_size
|
199
227
|
self.epsilon_low = epsilon_low
|
200
228
|
self.epsilon_high = epsilon_high
|
229
|
+
self.loss_type = loss_type
|
230
|
+
self.max_completion_length = max_completion_length
|
201
231
|
self.temperature = temperature
|
202
232
|
|
203
233
|
def forward(
|
@@ -229,6 +259,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
|
|
229
259
|
self.beta,
|
230
260
|
self.epsilon_low,
|
231
261
|
self.epsilon_high,
|
262
|
+
self.loss_type,
|
263
|
+
self.max_completion_length,
|
232
264
|
self.temperature,
|
233
265
|
self.compiled,
|
234
266
|
self.use_ref_model,
|
@@ -222,7 +222,7 @@ def lce_forward(
|
|
222
222
|
lm_head_weight=self.lm_head.weight,
|
223
223
|
labels=labels,
|
224
224
|
hidden_size=self.config.hidden_size,
|
225
|
-
|
225
|
+
final_logit_softcapping=self.config.final_logit_softcapping,
|
226
226
|
**loss_kwargs,
|
227
227
|
)
|
228
228
|
|
@@ -112,7 +112,7 @@ def causal_forward(
|
|
112
112
|
lm_head_weight=self.lm_head.weight,
|
113
113
|
labels=labels,
|
114
114
|
hidden_size=self.config.hidden_size,
|
115
|
-
|
115
|
+
final_logit_softcapping=self.config.final_logit_softcapping,
|
116
116
|
**loss_kwargs,
|
117
117
|
)
|
118
118
|
|
@@ -1,14 +1,18 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
import torch
|
1
4
|
import torch.nn as nn
|
2
5
|
|
3
6
|
import liger_kernel.transformers.functional as F
|
4
7
|
|
5
8
|
|
6
9
|
def fixed_fused_linear_cross_entropy(
|
7
|
-
hidden_states,
|
8
|
-
lm_head_weight,
|
9
|
-
target,
|
10
|
-
num_items_in_batch: int = None,
|
10
|
+
hidden_states: torch.Tensor,
|
11
|
+
lm_head_weight: torch.Tensor,
|
12
|
+
target: torch.Tensor,
|
13
|
+
num_items_in_batch: Optional[int] = None,
|
11
14
|
ignore_index: int = -100,
|
15
|
+
final_logit_softcapping: Optional[float] = None,
|
12
16
|
**kwargs,
|
13
17
|
):
|
14
18
|
reduction = "sum" if num_items_in_batch is not None else "mean"
|
@@ -18,7 +22,7 @@ def fixed_fused_linear_cross_entropy(
|
|
18
22
|
target,
|
19
23
|
reduction=reduction,
|
20
24
|
ignore_index=ignore_index,
|
21
|
-
|
25
|
+
softcap=final_logit_softcapping,
|
22
26
|
)
|
23
27
|
if reduction == "sum":
|
24
28
|
loss = loss / num_items_in_batch
|
@@ -31,15 +35,17 @@ def LigerForCausalLMLoss(
|
|
31
35
|
lm_head_weight,
|
32
36
|
labels,
|
33
37
|
hidden_size: int,
|
34
|
-
num_items_in_batch: int = None,
|
38
|
+
num_items_in_batch: Optional[int] = None,
|
35
39
|
ignore_index: int = -100,
|
40
|
+
shift_labels: Optional[torch.Tensor] = None,
|
41
|
+
final_logit_softcapping: Optional[float] = None,
|
36
42
|
**kwargs,
|
37
43
|
):
|
38
44
|
# Skip upcast since intermediate values for the loss are all fp32 in kernel
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
45
|
+
if shift_labels is None:
|
46
|
+
# Shift so that token < n predict n
|
47
|
+
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
|
48
|
+
shift_labels = labels[..., 1:].contiguous()
|
43
49
|
|
44
50
|
# Flatten the tokens
|
45
51
|
hidden_states = hidden_states.view(-1, hidden_size)
|
@@ -52,6 +58,7 @@ def LigerForCausalLMLoss(
|
|
52
58
|
shift_labels,
|
53
59
|
num_items_in_batch,
|
54
60
|
ignore_index,
|
61
|
+
final_logit_softcapping,
|
55
62
|
**kwargs,
|
56
63
|
)
|
57
64
|
return loss
|
@@ -27,6 +27,8 @@ class TorchLMHeadGRPO(torch.nn.Module):
|
|
27
27
|
epsilon_high: float = 0.2,
|
28
28
|
temperature: float = 1.0,
|
29
29
|
use_ref_model: bool = True,
|
30
|
+
loss_type: str = "bnpo",
|
31
|
+
max_completion_length: int | None = None,
|
30
32
|
):
|
31
33
|
super().__init__()
|
32
34
|
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
|
@@ -36,6 +38,10 @@ class TorchLMHeadGRPO(torch.nn.Module):
|
|
36
38
|
self.epsilon_high = epsilon_high
|
37
39
|
self.temperature = temperature
|
38
40
|
self.use_ref_model = use_ref_model
|
41
|
+
self.loss_type = loss_type
|
42
|
+
self.max_completion_length = max_completion_length
|
43
|
+
if self.loss_type == "dr_grpo":
|
44
|
+
assert self.max_completion_length is not None, "max_completion_length must be provided for dr_grpo"
|
39
45
|
|
40
46
|
def forward(
|
41
47
|
self,
|
@@ -89,8 +95,15 @@ class TorchLMHeadGRPO(torch.nn.Module):
|
|
89
95
|
kl_div = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
|
90
96
|
per_token_loss = per_token_loss + self.beta * kl_div
|
91
97
|
|
92
|
-
# Apply masking and
|
93
|
-
|
98
|
+
# Apply masking and calculate loss based on loss_type
|
99
|
+
if self.loss_type == "grpo":
|
100
|
+
loss = ((per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)).mean()
|
101
|
+
elif self.loss_type == "bnpo":
|
102
|
+
loss = (per_token_loss * attention_mask).sum() / torch.clamp(attention_mask.sum(), min=1.0)
|
103
|
+
elif self.loss_type == "dr_grpo":
|
104
|
+
loss = (per_token_loss * attention_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
|
105
|
+
else:
|
106
|
+
raise ValueError(f"Unknown loss type: {self.loss_type}")
|
94
107
|
|
95
108
|
# Compute metrics
|
96
109
|
metrics = []
|
@@ -115,6 +128,8 @@ class LigerLMHeadGRPO(torch.nn.Module):
|
|
115
128
|
epsilon_high: float = 0.2,
|
116
129
|
temperature: float = 1.0,
|
117
130
|
use_ref_model: bool = True,
|
131
|
+
loss_type: str = "bnpo",
|
132
|
+
max_completion_length: int | None = None,
|
118
133
|
):
|
119
134
|
super().__init__()
|
120
135
|
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
|
@@ -126,6 +141,8 @@ class LigerLMHeadGRPO(torch.nn.Module):
|
|
126
141
|
temperature=temperature,
|
127
142
|
use_ref_model=use_ref_model,
|
128
143
|
compiled=True,
|
144
|
+
loss_type=loss_type,
|
145
|
+
max_completion_length=max_completion_length,
|
129
146
|
)
|
130
147
|
|
131
148
|
def forward(
|
@@ -186,6 +203,7 @@ class LigerLMHeadGRPO(torch.nn.Module):
|
|
186
203
|
],
|
187
204
|
)
|
188
205
|
@pytest.mark.parametrize("old_per_token_logps", [True, False])
|
206
|
+
@pytest.mark.parametrize("loss_type", ["bnpo", "grpo", "dr_grpo"])
|
189
207
|
def test_correctness(
|
190
208
|
B,
|
191
209
|
T,
|
@@ -203,9 +221,12 @@ def test_correctness(
|
|
203
221
|
use_ref_per_token_logps,
|
204
222
|
use_ref_model,
|
205
223
|
old_per_token_logps,
|
224
|
+
loss_type,
|
206
225
|
):
|
207
226
|
# Reset torch compiler cache for each parameter of the test case
|
208
227
|
torch.compiler.reset()
|
228
|
+
max_completion_length = T if loss_type == "dr_grpo" else None
|
229
|
+
|
209
230
|
torch_lm_head_grpo = TorchLMHeadGRPO(
|
210
231
|
H=H,
|
211
232
|
V=V,
|
@@ -216,6 +237,8 @@ def test_correctness(
|
|
216
237
|
epsilon_high=epsilon_high,
|
217
238
|
temperature=temperature,
|
218
239
|
use_ref_model=use_ref_model,
|
240
|
+
loss_type=loss_type,
|
241
|
+
max_completion_length=max_completion_length,
|
219
242
|
)
|
220
243
|
liger_lm_head_grpo = LigerLMHeadGRPO(
|
221
244
|
H=H,
|
@@ -227,6 +250,8 @@ def test_correctness(
|
|
227
250
|
epsilon_high=epsilon_high,
|
228
251
|
temperature=temperature,
|
229
252
|
use_ref_model=use_ref_model,
|
253
|
+
loss_type=loss_type,
|
254
|
+
max_completion_length=max_completion_length,
|
230
255
|
)
|
231
256
|
|
232
257
|
# Initialize weights
|
@@ -319,7 +344,7 @@ def test_correctness(
|
|
319
344
|
loss1.backward()
|
320
345
|
loss2.backward()
|
321
346
|
|
322
|
-
# Check gradients match
|
347
|
+
# Check gradients match for loss_type
|
323
348
|
assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
|
324
349
|
assert_verbose_allclose(
|
325
350
|
torch_lm_head_grpo.lin.weight.grad,
|
@@ -351,6 +376,7 @@ def test_correctness(
|
|
351
376
|
],
|
352
377
|
)
|
353
378
|
@pytest.mark.parametrize("bias", [True, False])
|
379
|
+
@pytest.mark.parametrize("loss_type", ["bnpo", "grpo", "dr_grpo"])
|
354
380
|
def test_functional_correctness(
|
355
381
|
B,
|
356
382
|
T,
|
@@ -361,9 +387,11 @@ def test_functional_correctness(
|
|
361
387
|
atol,
|
362
388
|
rtol,
|
363
389
|
bias,
|
390
|
+
loss_type,
|
364
391
|
):
|
365
392
|
# Reset torch compiler cache for each parameter of the test case
|
366
393
|
torch.compiler.reset()
|
394
|
+
max_completion_length = T if loss_type == "dr_grpo" else None
|
367
395
|
_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
|
368
396
|
input1 = _input.detach().clone().requires_grad_(True)
|
369
397
|
input2 = _input.detach().clone().requires_grad_(True)
|
@@ -418,6 +446,8 @@ def test_functional_correctness(
|
|
418
446
|
0.04,
|
419
447
|
0.2,
|
420
448
|
0.2,
|
449
|
+
loss_type,
|
450
|
+
max_completion_length,
|
421
451
|
1.0,
|
422
452
|
True,
|
423
453
|
True,
|
@@ -439,6 +469,8 @@ def test_functional_correctness(
|
|
439
469
|
0.04,
|
440
470
|
0.2,
|
441
471
|
0.2,
|
472
|
+
loss_type,
|
473
|
+
max_completion_length,
|
442
474
|
1.0,
|
443
475
|
True,
|
444
476
|
True,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{liger_kernel_nightly-0.5.6.dev20250411201510 → liger_kernel_nightly-0.5.6.dev20250411224032}/NOTICE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|