liger-kernel-nightly 0.5.2.dev20250108102127__tar.gz → 0.5.2.dev20250109023714__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_orpo_loss.py +6 -4
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/fused_linear_preference.py +40 -12
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/orpo_loss.py +5 -2
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/trainer/orpo_trainer.py +16 -4
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/chunked_loss/test_orpo_loss.py +10 -8
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/utils.py +5 -3
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/Makefile +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/Acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/CONTRIBUTING.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/License.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/setup.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/convergence/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/convergence/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/convergence/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -45,12 +45,13 @@ def bench_memory_fused_linear_orpo_loss(
|
|
45
45
|
|
46
46
|
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
|
47
47
|
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
|
48
|
+
nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
|
48
49
|
|
49
50
|
def fwd():
|
50
51
|
if provider == "liger":
|
51
|
-
return liger_lm_head_orpo(_input, target)
|
52
|
+
return liger_lm_head_orpo(_input, target, nll_target)
|
52
53
|
elif provider == "huggingface":
|
53
|
-
return torch_lm_head_orpo(_input, target)
|
54
|
+
return torch_lm_head_orpo(_input, target, nll_target)
|
54
55
|
|
55
56
|
def full():
|
56
57
|
y = fwd()
|
@@ -91,12 +92,13 @@ def bench_speed_fused_linear_orpo_loss(
|
|
91
92
|
|
92
93
|
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
|
93
94
|
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
|
95
|
+
nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
|
94
96
|
|
95
97
|
def fwd():
|
96
98
|
if provider == "liger":
|
97
|
-
return liger_lm_head_orpo(_input, target)
|
99
|
+
return liger_lm_head_orpo(_input, target, nll_target)
|
98
100
|
elif provider == "huggingface":
|
99
|
-
return torch_lm_head_orpo(_input, target)
|
101
|
+
return torch_lm_head_orpo(_input, target, nll_target)
|
100
102
|
|
101
103
|
if mode == "forward":
|
102
104
|
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.5.2.dev20250108102127"
|
7
|
+
version = "0.5.2.dev20250109023714"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -27,6 +27,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
27
27
|
alpha=1.0,
|
28
28
|
beta=0.1,
|
29
29
|
compute_nll_loss=True,
|
30
|
+
nll_target=None,
|
30
31
|
compiled=True,
|
31
32
|
use_ref_model=False,
|
32
33
|
ref_input=None,
|
@@ -58,6 +59,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
58
59
|
alpha (float): Weight for the NLL loss.
|
59
60
|
beta (float): Weight for the preference loss.
|
60
61
|
compute_nll_loss (bool): Whether to compute NLL loss.
|
62
|
+
nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used.
|
61
63
|
compiled (bool): Whether to use torch compile for chunk accumulation.
|
62
64
|
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
63
65
|
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
@@ -96,11 +98,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
96
98
|
use_ref_model=use_ref_model,
|
97
99
|
ref_weight=ref_weight,
|
98
100
|
ref_bias=ref_bias,
|
101
|
+
full_nll_target=nll_target,
|
99
102
|
average_log_prob=average_log_prob,
|
100
103
|
**loss_kwargs,
|
101
104
|
)
|
102
105
|
|
103
|
-
def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
|
106
|
+
def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
|
104
107
|
"""
|
105
108
|
Fused forward and backward pass for a chunk of input and target.
|
106
109
|
"""
|
@@ -111,13 +114,18 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
111
114
|
target_chunk,
|
112
115
|
bias,
|
113
116
|
ref_input_chunk=ref_input_chunk,
|
117
|
+
chosen_nll_target_chunk=chosen_nll_target_chunk,
|
114
118
|
)
|
115
119
|
else:
|
116
120
|
return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
|
117
|
-
input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk
|
121
|
+
input_chunk,
|
122
|
+
weight,
|
123
|
+
target_chunk,
|
124
|
+
ref_input_chunk=ref_input_chunk,
|
125
|
+
chosen_nll_target_chunk=chosen_nll_target_chunk,
|
118
126
|
)
|
119
127
|
|
120
|
-
def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
|
128
|
+
def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
|
121
129
|
if bias is not None:
|
122
130
|
(
|
123
131
|
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
|
@@ -132,7 +140,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
132
140
|
*aux_outputs,
|
133
141
|
),
|
134
142
|
),
|
135
|
-
) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
|
143
|
+
) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
|
136
144
|
grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
|
137
145
|
else:
|
138
146
|
(
|
@@ -148,7 +156,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
148
156
|
*aux_outputs,
|
149
157
|
),
|
150
158
|
),
|
151
|
-
) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
|
159
|
+
) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
|
152
160
|
|
153
161
|
# Accumulate gradients
|
154
162
|
grad_weight.add_(chunk_grad_weight)
|
@@ -191,6 +199,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
191
199
|
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
|
192
200
|
_rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
|
193
201
|
|
202
|
+
if nll_target is not None:
|
203
|
+
_chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
|
204
|
+
|
194
205
|
if use_ref_model:
|
195
206
|
_ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
|
196
207
|
_ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
|
@@ -202,6 +213,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
202
213
|
rejected_target_chunk,
|
203
214
|
ref_chosen_input_chunk,
|
204
215
|
ref_rejected_input_chunk,
|
216
|
+
chosen_nll_target_chunk,
|
205
217
|
) in zip(
|
206
218
|
_chosen_input_chunks,
|
207
219
|
_rejected_input_chunks,
|
@@ -209,6 +221,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
209
221
|
_rejected_target_chunks,
|
210
222
|
(_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
|
211
223
|
(_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
|
224
|
+
(_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
|
212
225
|
strict=False,
|
213
226
|
):
|
214
227
|
input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
|
@@ -222,9 +235,10 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
222
235
|
torch._dynamo.mark_dynamic(target_chunk, 1)
|
223
236
|
torch._dynamo.mark_dynamic(target, 1)
|
224
237
|
torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
|
238
|
+
torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None
|
225
239
|
|
226
240
|
# accumulate loss, gradients, and metrics
|
227
|
-
accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
|
241
|
+
accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
|
228
242
|
|
229
243
|
# combine grad_chosen_inputs and grad_rejected_inputs
|
230
244
|
grad_inputs = grad_chosen_inputs + grad_rejected_inputs
|
@@ -258,7 +272,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
258
272
|
grad_weight = grad_weight * grad_output[0][0]
|
259
273
|
grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
|
260
274
|
|
261
|
-
return grad_input, grad_weight, None, grad_bias, None, None, None
|
275
|
+
return grad_input, grad_weight, None, grad_bias, None, None, None, None
|
262
276
|
|
263
277
|
@staticmethod
|
264
278
|
def chunk_forward(
|
@@ -268,6 +282,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
268
282
|
bias=None,
|
269
283
|
ignore_index=-100,
|
270
284
|
compute_nll_loss=True,
|
285
|
+
chosen_nll_target_chunk=None,
|
271
286
|
average_log_prob=True,
|
272
287
|
):
|
273
288
|
len_chosen_chunk = target_chunk.shape[0] // 2
|
@@ -278,9 +293,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
278
293
|
|
279
294
|
chosen_nll_loss = 0.0
|
280
295
|
if compute_nll_loss:
|
296
|
+
nll_labels = (
|
297
|
+
chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
|
298
|
+
)
|
281
299
|
chosen_nll_loss = F.nll_loss(
|
282
300
|
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
|
283
|
-
target_chunk[:len_chosen_chunk].view(-1),
|
301
|
+
nll_labels.view(-1),
|
284
302
|
reduction="sum",
|
285
303
|
ignore_index=ignore_index,
|
286
304
|
)
|
@@ -324,6 +342,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
324
342
|
ref_input_chunk=None,
|
325
343
|
ref_weight=None,
|
326
344
|
ref_bias=None,
|
345
|
+
full_nll_target=None,
|
346
|
+
chosen_nll_target_chunk=None,
|
327
347
|
average_log_prob=True,
|
328
348
|
**loss_kwargs,
|
329
349
|
):
|
@@ -343,6 +363,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
343
363
|
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
344
364
|
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
345
365
|
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
366
|
+
full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
|
367
|
+
chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used.
|
346
368
|
average_log_prob (bool): Whether to average log probabilities or the sum.
|
347
369
|
loss_kwargs (dict): Additional arguments for the loss function.
|
348
370
|
"""
|
@@ -359,9 +381,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
359
381
|
bias=bias,
|
360
382
|
ignore_index=ignore_index,
|
361
383
|
compute_nll_loss=compute_nll_loss,
|
384
|
+
chosen_nll_target_chunk=chosen_nll_target_chunk,
|
362
385
|
average_log_prob=average_log_prob,
|
363
386
|
)
|
364
|
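-
chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()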
|
387
|
+
if full_nll_target is not None:
|
388
|
+
chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
|
389
|
+
else:
|
390
|
+
chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
|
391
|
+
|
365
392
|
chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
|
366
393
|
rejected_logits_mean = rejected_logits.sum() / (
|
367
394
|
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
|
@@ -372,9 +399,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
372
399
|
(
|
373
400
|
ref_chosen_logps,
|
374
401
|
ref_rejected_logps,
|
375
|
-
ref_chosen_logits,
|
376
|
-
ref_rejected_logits,
|
377
|
-
ref_chosen_nll_loss,
|
402
|
+
_,
|
403
|
+
_,
|
404
|
+
_,
|
378
405
|
) = LigerFusedLinearPreferenceBase.chunk_forward(
|
379
406
|
ref_input_chunk,
|
380
407
|
ref_weight,
|
@@ -382,6 +409,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
382
409
|
ref_bias,
|
383
410
|
ignore_index=ignore_index,
|
384
411
|
compute_nll_loss=False, # We don't need NLL loss for the reference model
|
412
|
+
chosen_nll_target_chunk=None,
|
385
413
|
average_log_prob=average_log_prob,
|
386
414
|
)
|
387
415
|
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
|
@@ -52,6 +52,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
52
52
|
ignore_index=-100,
|
53
53
|
beta=0.1,
|
54
54
|
compute_nll_loss=True,
|
55
|
+
nll_target=None,
|
55
56
|
compiled=True,
|
56
57
|
):
|
57
58
|
return LigerFusedLinearPreferenceBase.forward(
|
@@ -64,13 +65,14 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
64
65
|
ignore_index=ignore_index,
|
65
66
|
beta=beta,
|
66
67
|
compute_nll_loss=compute_nll_loss,
|
68
|
+
nll_target=nll_target,
|
67
69
|
compiled=compiled,
|
68
70
|
)
|
69
71
|
|
70
72
|
@staticmethod
|
71
73
|
def backward(ctx, *grad_output):
|
72
74
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
73
|
-
return *grads, None, None, None, None
|
75
|
+
return *grads, None, None, None, None, None
|
74
76
|
|
75
77
|
|
76
78
|
class LigerFusedLinearORPOLoss(torch.nn.Module):
|
@@ -96,7 +98,7 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
|
|
96
98
|
self.compute_nll_loss = compute_nll_loss
|
97
99
|
self.compiled = compiled
|
98
100
|
|
99
|
-
def forward(self, lin_weight, _input, target, bias=None):
|
101
|
+
def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
|
100
102
|
return LigerFusedLinearORPOFunction.apply(
|
101
103
|
_input,
|
102
104
|
lin_weight,
|
@@ -105,5 +107,6 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
|
|
105
107
|
self.ignore_index,
|
106
108
|
self.beta,
|
107
109
|
self.compute_nll_loss,
|
110
|
+
nll_target,
|
108
111
|
self.compiled,
|
109
112
|
)
|
@@ -93,6 +93,13 @@ class LigerORPOTrainer(ORPOTrainer):
|
|
93
93
|
if self.aux_loss_enabled:
|
94
94
|
model_kwargs["output_router_logits"] = True
|
95
95
|
|
96
|
+
if self.is_encoder_decoder:
|
97
|
+
labels = concatenated_batch["concatenated_labels"].clone()
|
98
|
+
else:
|
99
|
+
labels = concatenated_batch["concatenated_input_ids"].clone()
|
100
|
+
attention_mask = concatenated_batch["concatenated_attention_mask"]
|
101
|
+
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
|
102
|
+
|
96
103
|
if isinstance(model, FullyShardedDataParallel):
|
97
104
|
outputs = _FSDPForwardRedirection()(
|
98
105
|
model,
|
@@ -114,15 +121,20 @@ class LigerORPOTrainer(ORPOTrainer):
|
|
114
121
|
|
115
122
|
orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
|
116
123
|
|
117
|
-
def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
|
118
|
-
return orpo_loss_fn(lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias)
|
124
|
+
def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target):
|
125
|
+
return orpo_loss_fn(
|
126
|
+
lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target
|
127
|
+
)
|
119
128
|
|
120
129
|
orpo_loss, aux_outputs = _FSDPForwardRedirection()(
|
121
130
|
model,
|
122
131
|
orpo_partial,
|
123
132
|
model.lm_head,
|
124
|
-
outputs.last_hidden_state,
|
125
|
-
concatenated_batch["concatenated_labels"]
|
133
|
+
outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
|
134
|
+
concatenated_batch["concatenated_labels"][:, 1:]
|
135
|
+
if not self.is_encoder_decoder
|
136
|
+
else concatenated_batch["concatenated_labels"],
|
137
|
+
labels[:, 1:] if not self.is_encoder_decoder else labels,
|
126
138
|
)
|
127
139
|
# if aux_loss_enabled, add the aux_loss to the orpo_loss
|
128
140
|
if self.aux_loss_enabled:
|
@@ -86,8 +86,8 @@ class TorchLMHeadORPO(torch.nn.Module):
|
|
86
86
|
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
|
87
87
|
self.orpo_loss = HFORPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics
|
88
88
|
|
89
|
-
def forward(self, x, y):
|
90
|
-
return self.orpo_loss(self.lin.weight, x, y, self.lin.bias)
|
89
|
+
def forward(self, x, y, nll_target=None):
|
90
|
+
return self.orpo_loss(self.lin.weight, x, y, self.lin.bias, nll_target=nll_target)
|
91
91
|
|
92
92
|
|
93
93
|
class LigerLMHeadORPO(torch.nn.Module):
|
@@ -104,8 +104,8 @@ class LigerLMHeadORPO(torch.nn.Module):
|
|
104
104
|
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
|
105
105
|
self.orpo_loss = LigerFusedLinearORPOLoss(ignore_index=ignore_index, beta=beta)
|
106
106
|
|
107
|
-
def forward(self, x, y):
|
108
|
-
return self.orpo_loss(self.lin.weight, x, y, self.lin.bias)
|
107
|
+
def forward(self, x, y, nll_target=None):
|
108
|
+
return self.orpo_loss(self.lin.weight, x, y, self.lin.bias, nll_target=nll_target)
|
109
109
|
|
110
110
|
|
111
111
|
@pytest.mark.parametrize(
|
@@ -164,13 +164,15 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index,
|
|
164
164
|
device=device,
|
165
165
|
dtype=torch.long,
|
166
166
|
)
|
167
|
+
nll_target = torch.randint(0, V, (B, T), device=device, dtype=torch.long)
|
168
|
+
|
167
169
|
# Assign some random number of elements as ignore_index
|
168
170
|
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
|
169
171
|
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
|
170
172
|
target.view(-1)[indices_to_assign] = ignore_index
|
171
173
|
|
172
|
-
loss1, aggregated_aux_outputs1 = torch_lm_head_orpo(input1, target)
|
173
|
-
loss2, aggregated_aux_outputs2 = liger_lm_head_orpo(input2, target)
|
174
|
+
loss1, aggregated_aux_outputs1 = torch_lm_head_orpo(input1, target, nll_target)
|
175
|
+
loss2, aggregated_aux_outputs2 = liger_lm_head_orpo(input2, target, nll_target)
|
174
176
|
|
175
177
|
assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
|
176
178
|
|
@@ -244,8 +246,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias):
|
|
244
246
|
bias1 = _bias.detach().clone().requires_grad_(True) if bias else None
|
245
247
|
bias2 = _bias.detach().clone().requires_grad_(True) if bias else None
|
246
248
|
|
247
|
-
loss1,
|
248
|
-
loss2,
|
249
|
+
loss1, _ = LigerFusedLinearORPOFunction.apply(input1, weight1, target, bias1)
|
250
|
+
loss2, _ = liger_fused_linear_orpo(input2, weight2, target, bias2)
|
249
251
|
|
250
252
|
assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
|
251
253
|
|
@@ -406,8 +406,9 @@ class HFAlignmentLoss:
|
|
406
406
|
_input: torch.FloatTensor,
|
407
407
|
weight: torch.FloatTensor,
|
408
408
|
target: torch.LongTensor,
|
409
|
-
bias: torch.FloatTensor = None,
|
409
|
+
bias: torch.FloatTensor | None = None,
|
410
410
|
average_log_prob: bool = True,
|
411
|
+
nll_target: torch.LongTensor | None = None,
|
411
412
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
412
413
|
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
413
414
|
|
@@ -430,7 +431,7 @@ class HFAlignmentLoss:
|
|
430
431
|
loss = loss_fct(logits, labels)
|
431
432
|
return loss
|
432
433
|
|
433
|
-
labels = target
|
434
|
+
labels = nll_target if nll_target is not None else target
|
434
435
|
chosen_nll_loss = torch.tensor(0.0, device=all_logits.device)
|
435
436
|
if self.compute_nll_loss:
|
436
437
|
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
@@ -465,10 +466,11 @@ class HFAlignmentLoss:
|
|
465
466
|
ref_weight: torch.FloatTensor = None,
|
466
467
|
ref_bias: torch.FloatTensor = None,
|
467
468
|
average_log_prob: bool = True,
|
469
|
+
nll_target: torch.LongTensor = None,
|
468
470
|
):
|
469
471
|
"""Compute the loss metrics for the given batch of inputs for train or test."""
|
470
472
|
|
471
|
-
forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob)
|
473
|
+
forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob, nll_target)
|
472
474
|
(
|
473
475
|
policy_chosen_logps,
|
474
476
|
policy_rejected_logps,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{liger_kernel_nightly-0.5.2.dev20250108102127 → liger_kernel_nightly-0.5.2.dev20250109023714}/NOTICE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|