liger-kernel-nightly 0.5.4.dev20250305231637__tar.gz → 0.5.4.dev20250307064336__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/cpo_loss.py +42 -3
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/dpo_loss.py +26 -1
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +23 -3
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/grpo_loss.py +33 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/jsd_loss.py +8 -1
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/kto_loss.py +29 -1
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/orpo_loss.py +33 -2
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/simpo_loss.py +44 -9
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/Makefile +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/setup.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/test/utils.py +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel_nightly"
|
|
7
|
-
version = "0.5.4.
|
|
7
|
+
version = "0.5.4.dev20250307064336"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -53,7 +53,27 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
53
53
|
label_smoothing=0.0,
|
|
54
54
|
compute_nll_loss=True,
|
|
55
55
|
compiled=True,
|
|
56
|
+
average_log_prob=False,
|
|
57
|
+
chunk_size=1,
|
|
56
58
|
):
|
|
59
|
+
"""
|
|
60
|
+
Fused linear layer with CPO loss.
|
|
61
|
+
Args:
|
|
62
|
+
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
63
|
+
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
|
|
64
|
+
target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
|
|
65
|
+
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
|
|
66
|
+
ignore_index (int): Index to ignore in loss computation
|
|
67
|
+
beta (float): Weight for the odds ratio loss
|
|
68
|
+
alpha (float): Weight for the alpha parameter
|
|
69
|
+
label_smoothing (float): Label smoothing factor
|
|
70
|
+
compute_nll_loss (bool): Whether to compute the NLL loss
|
|
71
|
+
compiled (bool): Whether to use torch compile
|
|
72
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token
|
|
73
|
+
chunk_size (int): Size of chunks for processing.
|
|
74
|
+
Returns:
|
|
75
|
+
torch.Tensor: Computed loss
|
|
76
|
+
"""
|
|
57
77
|
return super().forward(
|
|
58
78
|
cls=cls,
|
|
59
79
|
ctx=ctx,
|
|
@@ -66,14 +86,15 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
66
86
|
beta=beta,
|
|
67
87
|
label_smoothing=label_smoothing,
|
|
68
88
|
compute_nll_loss=compute_nll_loss,
|
|
69
|
-
average_log_prob=
|
|
89
|
+
average_log_prob=average_log_prob,
|
|
70
90
|
compiled=compiled,
|
|
91
|
+
chunk_size=chunk_size,
|
|
71
92
|
)
|
|
72
93
|
|
|
73
94
|
@staticmethod
|
|
74
95
|
def backward(ctx, *grad_output):
|
|
75
96
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
|
76
|
-
return *grads, None, None, None, None, None, None
|
|
97
|
+
return *grads, None, None, None, None, None, None, None, None
|
|
77
98
|
|
|
78
99
|
|
|
79
100
|
class LigerFusedLinearCPOLoss(torch.nn.Module):
|
|
@@ -89,11 +110,19 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
|
|
|
89
110
|
label_smoothing: float = 0.0,
|
|
90
111
|
compute_nll_loss: bool = True,
|
|
91
112
|
compiled: bool = True,
|
|
113
|
+
average_log_prob: bool = False,
|
|
114
|
+
chunk_size: int = 1,
|
|
92
115
|
):
|
|
93
116
|
"""
|
|
94
117
|
Args:
|
|
95
118
|
ignore_index (int): Index to ignore in the loss.
|
|
96
119
|
beta (float): Weight for the odds ratio loss.
|
|
120
|
+
alpha (float): Weight for the alpha parameter.
|
|
121
|
+
label_smoothing (float): Label smoothing factor.
|
|
122
|
+
compute_nll_loss (bool): Whether to compute the NLL loss.
|
|
123
|
+
compiled (bool): Whether to use the torch compiled kernel.
|
|
124
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token.
|
|
125
|
+
chunk_size (int): Size of chunks for processing.
|
|
97
126
|
"""
|
|
98
127
|
super().__init__()
|
|
99
128
|
self.ignore_index = ignore_index
|
|
@@ -102,8 +131,16 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
|
|
|
102
131
|
self.label_smoothing = label_smoothing
|
|
103
132
|
self.compute_nll_loss = compute_nll_loss
|
|
104
133
|
self.compiled = compiled
|
|
134
|
+
self.average_log_prob = average_log_prob
|
|
135
|
+
self.chunk_size = chunk_size
|
|
105
136
|
|
|
106
|
-
def forward(
|
|
137
|
+
def forward(
|
|
138
|
+
self,
|
|
139
|
+
lin_weight,
|
|
140
|
+
_input,
|
|
141
|
+
target,
|
|
142
|
+
bias=None,
|
|
143
|
+
):
|
|
107
144
|
return LigerFusedLinearCPOFunction.apply(
|
|
108
145
|
_input,
|
|
109
146
|
lin_weight,
|
|
@@ -115,4 +152,6 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
|
|
|
115
152
|
self.label_smoothing,
|
|
116
153
|
self.compute_nll_loss,
|
|
117
154
|
self.compiled,
|
|
155
|
+
self.average_log_prob,
|
|
156
|
+
self.chunk_size,
|
|
118
157
|
)
|
|
@@ -68,7 +68,27 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
68
68
|
compute_nll_loss=False,
|
|
69
69
|
compiled=True,
|
|
70
70
|
use_ref_model=True,
|
|
71
|
+
chunk_size=1,
|
|
71
72
|
):
|
|
73
|
+
"""
|
|
74
|
+
Fused linear layer with DPO loss.
|
|
75
|
+
Args:
|
|
76
|
+
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
77
|
+
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
|
|
78
|
+
target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
|
|
79
|
+
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
|
|
80
|
+
ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
81
|
+
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
|
|
82
|
+
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
|
|
83
|
+
ignore_index (int): Index to ignore in loss computation
|
|
84
|
+
beta (float): Weight for the odds ratio loss
|
|
85
|
+
compute_nll_loss (bool): Whether to compute the NLL loss
|
|
86
|
+
compiled (bool): Whether to use torch compile
|
|
87
|
+
use_ref_model (bool): Whether to use a reference model
|
|
88
|
+
chunk_size (int): Size of chunks for processing.
|
|
89
|
+
Returns:
|
|
90
|
+
torch.Tensor: Computed loss
|
|
91
|
+
"""
|
|
72
92
|
return super().forward(
|
|
73
93
|
cls=cls,
|
|
74
94
|
ctx=ctx,
|
|
@@ -84,12 +104,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
84
104
|
ref_input=ref_input,
|
|
85
105
|
ref_weight=ref_weight,
|
|
86
106
|
ref_bias=ref_bias,
|
|
107
|
+
chunk_size=chunk_size,
|
|
87
108
|
)
|
|
88
109
|
|
|
89
110
|
@staticmethod
|
|
90
111
|
def backward(ctx, *grad_output):
|
|
91
112
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
|
92
|
-
return *grads, None, None, None, None, None, None, None, None
|
|
113
|
+
return *grads, None, None, None, None, None, None, None, None, None
|
|
93
114
|
|
|
94
115
|
|
|
95
116
|
class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
@@ -104,6 +125,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
104
125
|
compute_nll_loss: bool = False,
|
|
105
126
|
compiled: bool = True,
|
|
106
127
|
use_ref_model: bool = True,
|
|
128
|
+
chunk_size: int = 1,
|
|
107
129
|
):
|
|
108
130
|
"""
|
|
109
131
|
Args:
|
|
@@ -112,6 +134,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
112
134
|
compute_nll_loss (bool): Whether to compute the NLL loss.
|
|
113
135
|
compiled (bool): Whether to use the torch compiled kernel.
|
|
114
136
|
use_ref_model (bool): Whether to use a reference model for the DPO loss.
|
|
137
|
+
chunk_size (int): Size of chunks for processing.
|
|
115
138
|
"""
|
|
116
139
|
super().__init__()
|
|
117
140
|
self.ignore_index = ignore_index
|
|
@@ -119,6 +142,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
119
142
|
self.compute_nll_loss = compute_nll_loss
|
|
120
143
|
self.compiled = compiled
|
|
121
144
|
self.use_ref_model = use_ref_model
|
|
145
|
+
self.chunk_size = chunk_size
|
|
122
146
|
|
|
123
147
|
def forward(
|
|
124
148
|
self,
|
|
@@ -143,4 +167,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
143
167
|
self.compute_nll_loss,
|
|
144
168
|
self.compiled,
|
|
145
169
|
self.use_ref_model,
|
|
170
|
+
self.chunk_size,
|
|
146
171
|
)
|
|
@@ -29,8 +29,27 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
|
|
|
29
29
|
ref_input=None,
|
|
30
30
|
ref_weight=None,
|
|
31
31
|
ref_bias=None,
|
|
32
|
+
chunk_size=1,
|
|
32
33
|
):
|
|
33
|
-
"""Chunked forward pass for RLHF loss computation.
|
|
34
|
+
"""Chunked forward pass for RLHF loss computation.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
cls: The class
|
|
38
|
+
ctx: Context for backward
|
|
39
|
+
_input: Input tensor
|
|
40
|
+
weight: Weight tensor
|
|
41
|
+
attention_mask: Attention mask tensor
|
|
42
|
+
rewards: Rewards tensor
|
|
43
|
+
bias: Bias tensor
|
|
44
|
+
num_generations: Number of generations per prompt
|
|
45
|
+
beta: Weight for the KL penalty
|
|
46
|
+
compiled: Whether to use torch compile
|
|
47
|
+
use_ref_model: Whether to use a reference model
|
|
48
|
+
ref_input: Reference model input tensor
|
|
49
|
+
ref_weight: Reference model weight tensor
|
|
50
|
+
ref_bias: Reference model bias tensor
|
|
51
|
+
chunk_size: Size of chunks for processing in other loss modules
|
|
52
|
+
"""
|
|
34
53
|
# Save for backward
|
|
35
54
|
ctx.beta = beta
|
|
36
55
|
ctx.rewards = rewards
|
|
@@ -106,7 +125,7 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
|
|
|
106
125
|
if compiled:
|
|
107
126
|
accumulate_chunk = torch.compile(accumulate_chunk)
|
|
108
127
|
|
|
109
|
-
# Process input in chunks
|
|
128
|
+
# Process input in chunks based on num_generations
|
|
110
129
|
chunks = max(1, _input.shape[0] // num_generations)
|
|
111
130
|
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
|
|
112
131
|
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
|
|
@@ -210,11 +229,12 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
|
|
|
210
229
|
None, # grad_attention_mask
|
|
211
230
|
None, # grad_rewards
|
|
212
231
|
grad_bias,
|
|
213
|
-
None, #
|
|
232
|
+
None, # grad_num_generations
|
|
214
233
|
None, # grad_beta
|
|
215
234
|
None, # grad_compiled
|
|
216
235
|
None, # grad_use_ref_model
|
|
217
236
|
None, # grad_ref_input
|
|
218
237
|
None, # grad_ref_weight
|
|
219
238
|
None, # grad_ref_bias
|
|
239
|
+
None, # grad_chunk_size
|
|
220
240
|
)
|
|
@@ -79,7 +79,27 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
|
|
|
79
79
|
compiled=True,
|
|
80
80
|
use_ref_model=True,
|
|
81
81
|
num_generations=1,
|
|
82
|
+
chunk_size=1,
|
|
82
83
|
):
|
|
84
|
+
"""
|
|
85
|
+
Fused linear layer with GRPO loss.
|
|
86
|
+
Args:
|
|
87
|
+
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
88
|
+
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
|
|
89
|
+
attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
|
|
90
|
+
rewards (torch.Tensor): Rewards tensor. Shape: (batch_size,)
|
|
91
|
+
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
|
|
92
|
+
ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
93
|
+
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
|
|
94
|
+
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
|
|
95
|
+
beta (float): Weight for the KL penalty
|
|
96
|
+
compiled (bool): Whether to use torch compile
|
|
97
|
+
use_ref_model (bool): Whether to use a reference model
|
|
98
|
+
num_generations (int): Number of generations per prompt
|
|
99
|
+
chunk_size (int): Size of chunks for processing.
|
|
100
|
+
Returns:
|
|
101
|
+
torch.Tensor: Computed loss
|
|
102
|
+
"""
|
|
83
103
|
return super().forward(
|
|
84
104
|
cls=cls,
|
|
85
105
|
ctx=ctx,
|
|
@@ -95,6 +115,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
|
|
|
95
115
|
compiled=compiled,
|
|
96
116
|
use_ref_model=use_ref_model,
|
|
97
117
|
num_generations=num_generations,
|
|
118
|
+
chunk_size=chunk_size,
|
|
98
119
|
)
|
|
99
120
|
|
|
100
121
|
@staticmethod
|
|
@@ -115,6 +136,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
|
|
|
115
136
|
None, # grad_compiled
|
|
116
137
|
None, # grad_use_ref_model
|
|
117
138
|
None, # grad_num_generations
|
|
139
|
+
None, # grad_chunk_size
|
|
118
140
|
)
|
|
119
141
|
|
|
120
142
|
|
|
@@ -127,12 +149,22 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
|
|
|
127
149
|
compiled: bool = True,
|
|
128
150
|
use_ref_model: bool = True,
|
|
129
151
|
num_generations: int = 1,
|
|
152
|
+
chunk_size: int = 1,
|
|
130
153
|
):
|
|
154
|
+
"""
|
|
155
|
+
Args:
|
|
156
|
+
beta (float): Weight for the KL penalty.
|
|
157
|
+
compiled (bool): Whether to use torch compile.
|
|
158
|
+
use_ref_model (bool): Whether to use a reference model.
|
|
159
|
+
num_generations (int): Number of generations per prompt.
|
|
160
|
+
chunk_size (int): Size of chunks for processing.
|
|
161
|
+
"""
|
|
131
162
|
super().__init__()
|
|
132
163
|
self.beta = beta
|
|
133
164
|
self.compiled = compiled
|
|
134
165
|
self.use_ref_model = use_ref_model
|
|
135
166
|
self.num_generations = num_generations
|
|
167
|
+
self.chunk_size = chunk_size
|
|
136
168
|
|
|
137
169
|
def forward(
|
|
138
170
|
self,
|
|
@@ -158,4 +190,5 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
|
|
|
158
190
|
self.compiled,
|
|
159
191
|
self.use_ref_model,
|
|
160
192
|
self.num_generations,
|
|
193
|
+
self.chunk_size,
|
|
161
194
|
)
|
|
@@ -47,6 +47,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
|
47
47
|
ignore_index: int = -100,
|
|
48
48
|
temperature: float = 1.0,
|
|
49
49
|
compiled: bool = True,
|
|
50
|
+
chunk_size: int = 1024,
|
|
50
51
|
):
|
|
51
52
|
"""
|
|
52
53
|
Fused linear layer with JSD distillation loss.
|
|
@@ -62,6 +63,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
|
62
63
|
ignore_index (int): Index to ignore in loss computation
|
|
63
64
|
temperature (float): Temperature for softening/sharpening distributions
|
|
64
65
|
compiled (bool): Whether to use torch compile
|
|
66
|
+
chunk_size (int): Size of chunks for processing.
|
|
65
67
|
Returns:
|
|
66
68
|
torch.Tensor: Computed loss
|
|
67
69
|
"""
|
|
@@ -75,7 +77,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
|
75
77
|
target=true_labels,
|
|
76
78
|
student_bias=student_bias,
|
|
77
79
|
teacher_bias=teacher_bias,
|
|
78
|
-
chunk_size=
|
|
80
|
+
chunk_size=chunk_size,
|
|
79
81
|
weight_hard_loss=weight_hard_loss,
|
|
80
82
|
weight_soft_loss=weight_soft_loss,
|
|
81
83
|
beta=beta,
|
|
@@ -97,6 +99,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
|
97
99
|
None, # ignore_index
|
|
98
100
|
None, # temperature
|
|
99
101
|
None, # compiled
|
|
102
|
+
None, # chunk_size
|
|
100
103
|
)
|
|
101
104
|
|
|
102
105
|
|
|
@@ -113,6 +116,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
|
|
|
113
116
|
ignore_index: int = -100,
|
|
114
117
|
temperature: float = 1.0,
|
|
115
118
|
compiled: bool = True,
|
|
119
|
+
chunk_size: int = 1024,
|
|
116
120
|
):
|
|
117
121
|
"""
|
|
118
122
|
Args:
|
|
@@ -122,6 +126,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
|
|
|
122
126
|
temperature (float): Temperature for softening distributions
|
|
123
127
|
compiled (bool): Whether to use torch compile
|
|
124
128
|
beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
|
|
129
|
+
chunk_size (int): Size of chunks for processing.
|
|
125
130
|
"""
|
|
126
131
|
super().__init__()
|
|
127
132
|
assert temperature != 0, "Temperature cannot be 0."
|
|
@@ -131,6 +136,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
|
|
|
131
136
|
self.temperature = temperature
|
|
132
137
|
self.compiled = compiled
|
|
133
138
|
self.beta = beta
|
|
139
|
+
self.chunk_size = chunk_size
|
|
134
140
|
|
|
135
141
|
def forward(
|
|
136
142
|
self,
|
|
@@ -169,4 +175,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
|
|
|
169
175
|
self.ignore_index,
|
|
170
176
|
self.temperature,
|
|
171
177
|
self.compiled,
|
|
178
|
+
self.chunk_size,
|
|
172
179
|
)
|
|
@@ -86,7 +86,29 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
|
|
|
86
86
|
compiled=True,
|
|
87
87
|
use_ref_model=True,
|
|
88
88
|
average_log_prob=False,
|
|
89
|
+
chunk_size=1,
|
|
89
90
|
):
|
|
91
|
+
"""
|
|
92
|
+
Fused linear layer with KTO loss.
|
|
93
|
+
Args:
|
|
94
|
+
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
95
|
+
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
|
|
96
|
+
target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
|
|
97
|
+
preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
|
|
98
|
+
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
|
|
99
|
+
ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
100
|
+
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
|
|
101
|
+
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
|
|
102
|
+
kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
|
|
103
|
+
ignore_index (int): Index to ignore in loss computation
|
|
104
|
+
beta (float): Temperature parameter for the KTO loss
|
|
105
|
+
compiled (bool): Whether to use torch compile
|
|
106
|
+
use_ref_model (bool): Whether to use a reference model
|
|
107
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token
|
|
108
|
+
chunk_size (int): Size of chunks for processing
|
|
109
|
+
Returns:
|
|
110
|
+
torch.Tensor: Computed loss
|
|
111
|
+
"""
|
|
90
112
|
return super().forward(
|
|
91
113
|
cls=cls,
|
|
92
114
|
ctx=ctx,
|
|
@@ -104,6 +126,7 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
|
|
|
104
126
|
ref_bias=ref_bias,
|
|
105
127
|
average_log_prob=average_log_prob,
|
|
106
128
|
kl=kl,
|
|
129
|
+
chunk_size=chunk_size,
|
|
107
130
|
)
|
|
108
131
|
|
|
109
132
|
@staticmethod
|
|
@@ -121,6 +144,7 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
|
|
|
121
144
|
None,
|
|
122
145
|
None,
|
|
123
146
|
None,
|
|
147
|
+
None,
|
|
124
148
|
)
|
|
125
149
|
|
|
126
150
|
|
|
@@ -136,6 +160,7 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
|
|
|
136
160
|
compiled: bool = True,
|
|
137
161
|
use_ref_model: bool = False,
|
|
138
162
|
average_log_prob: bool = False,
|
|
163
|
+
chunk_size: int = 1,
|
|
139
164
|
):
|
|
140
165
|
"""
|
|
141
166
|
Args:
|
|
@@ -143,7 +168,8 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
|
|
|
143
168
|
beta (float): Temperature parameter for the KTO loss
|
|
144
169
|
compiled (bool): Whether to use compiled operations
|
|
145
170
|
use_ref_model (bool): Whether to use a reference model for the DPO loss.
|
|
146
|
-
average_log_prob (bool): Whether to average the log probability per non-masked token
|
|
171
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token
|
|
172
|
+
chunk_size (int): Size of chunks for processing
|
|
147
173
|
"""
|
|
148
174
|
super().__init__()
|
|
149
175
|
self.ignore_index = ignore_index
|
|
@@ -151,6 +177,7 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
|
|
|
151
177
|
self.compiled = compiled
|
|
152
178
|
self.use_ref_model = use_ref_model
|
|
153
179
|
self.average_log_prob = average_log_prob
|
|
180
|
+
self.chunk_size = chunk_size
|
|
154
181
|
|
|
155
182
|
def forward(
|
|
156
183
|
self,
|
|
@@ -179,4 +206,5 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
|
|
|
179
206
|
self.compiled,
|
|
180
207
|
self.use_ref_model,
|
|
181
208
|
self.average_log_prob,
|
|
209
|
+
self.chunk_size,
|
|
182
210
|
)
|
|
@@ -55,7 +55,24 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
55
55
|
compute_nll_loss=True,
|
|
56
56
|
nll_target=None,
|
|
57
57
|
compiled=True,
|
|
58
|
+
chunk_size=1,
|
|
58
59
|
):
|
|
60
|
+
"""
|
|
61
|
+
Fused linear layer with ORPO loss.
|
|
62
|
+
Args:
|
|
63
|
+
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
64
|
+
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
|
|
65
|
+
target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
|
|
66
|
+
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
|
|
67
|
+
ignore_index (int): Index to ignore in loss computation
|
|
68
|
+
beta (float): Weight for the odds ratio loss
|
|
69
|
+
compute_nll_loss (bool): Whether to compute the NLL loss
|
|
70
|
+
nll_target (torch.LongTensor, optional): Target tensor for NLL loss. Shape: (batch_size * seq_len,)
|
|
71
|
+
compiled (bool): Whether to use torch compile
|
|
72
|
+
chunk_size (int): Size of chunks for processing
|
|
73
|
+
Returns:
|
|
74
|
+
torch.Tensor: Computed loss
|
|
75
|
+
"""
|
|
59
76
|
return super().forward(
|
|
60
77
|
cls=cls,
|
|
61
78
|
ctx=ctx,
|
|
@@ -68,12 +85,13 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
68
85
|
compute_nll_loss=compute_nll_loss,
|
|
69
86
|
nll_target=nll_target,
|
|
70
87
|
compiled=compiled,
|
|
88
|
+
chunk_size=chunk_size,
|
|
71
89
|
)
|
|
72
90
|
|
|
73
91
|
@staticmethod
|
|
74
92
|
def backward(ctx, *grad_output):
|
|
75
93
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
|
76
|
-
return *grads, None, None, None, None, None
|
|
94
|
+
return *grads, None, None, None, None, None, None
|
|
77
95
|
|
|
78
96
|
|
|
79
97
|
class LigerFusedLinearORPOLoss(torch.nn.Module):
|
|
@@ -87,19 +105,31 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
|
|
|
87
105
|
beta: float = 0.1,
|
|
88
106
|
compute_nll_loss: bool = True,
|
|
89
107
|
compiled: bool = True,
|
|
108
|
+
chunk_size: int = 1,
|
|
90
109
|
):
|
|
91
110
|
"""
|
|
92
111
|
Args:
|
|
93
112
|
ignore_index (int): Index to ignore in the loss.
|
|
94
113
|
beta (float): Weight for the odds ratio loss.
|
|
114
|
+
compute_nll_loss (bool): Whether to compute the NLL loss.
|
|
115
|
+
compiled (bool): Whether to use the torch compiled kernel.
|
|
116
|
+
chunk_size (int): Size of chunks for processing.
|
|
95
117
|
"""
|
|
96
118
|
super().__init__()
|
|
97
119
|
self.ignore_index = ignore_index
|
|
98
120
|
self.beta = beta
|
|
99
121
|
self.compute_nll_loss = compute_nll_loss
|
|
100
122
|
self.compiled = compiled
|
|
123
|
+
self.chunk_size = chunk_size
|
|
101
124
|
|
|
102
|
-
def forward(
|
|
125
|
+
def forward(
|
|
126
|
+
self,
|
|
127
|
+
lin_weight,
|
|
128
|
+
_input,
|
|
129
|
+
target,
|
|
130
|
+
bias=None,
|
|
131
|
+
nll_target=None,
|
|
132
|
+
):
|
|
103
133
|
return LigerFusedLinearORPOFunction.apply(
|
|
104
134
|
_input,
|
|
105
135
|
lin_weight,
|
|
@@ -110,4 +140,5 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
|
|
|
110
140
|
self.compute_nll_loss,
|
|
111
141
|
nll_target,
|
|
112
142
|
self.compiled,
|
|
143
|
+
self.chunk_size,
|
|
113
144
|
)
|
|
@@ -62,27 +62,47 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
62
62
|
compute_nll_loss=False,
|
|
63
63
|
compiled=True,
|
|
64
64
|
gamma=0.5,
|
|
65
|
+
chunk_size=1,
|
|
65
66
|
):
|
|
67
|
+
"""
|
|
68
|
+
Fused linear layer with SimPO loss.
|
|
69
|
+
Args:
|
|
70
|
+
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
71
|
+
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
|
|
72
|
+
target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
|
|
73
|
+
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
|
|
74
|
+
ignore_index (int): Index to ignore in loss computation
|
|
75
|
+
beta (float): Weight for the odds ratio loss
|
|
76
|
+
alpha (float): Weight for the alpha parameter
|
|
77
|
+
label_smoothing (float): Label smoothing factor
|
|
78
|
+
compute_nll_loss (bool): Whether to compute the NLL loss
|
|
79
|
+
compiled (bool): Whether to use torch compile
|
|
80
|
+
gamma (float): Weight for the gamma parameter
|
|
81
|
+
chunk_size (int): Size of chunks for processing
|
|
82
|
+
Returns:
|
|
83
|
+
torch.Tensor: Computed loss
|
|
84
|
+
"""
|
|
66
85
|
return super().forward(
|
|
67
|
-
cls,
|
|
68
|
-
ctx,
|
|
69
|
-
_input,
|
|
70
|
-
weight,
|
|
71
|
-
target,
|
|
72
|
-
bias,
|
|
73
|
-
compute_nll_loss=compute_nll_loss,
|
|
86
|
+
cls=cls,
|
|
87
|
+
ctx=ctx,
|
|
88
|
+
_input=_input,
|
|
89
|
+
weight=weight,
|
|
90
|
+
target=target,
|
|
91
|
+
bias=bias,
|
|
74
92
|
ignore_index=ignore_index,
|
|
75
93
|
alpha=alpha,
|
|
76
94
|
beta=beta,
|
|
77
95
|
label_smoothing=label_smoothing,
|
|
96
|
+
compute_nll_loss=compute_nll_loss,
|
|
78
97
|
compiled=compiled,
|
|
79
98
|
gamma=gamma,
|
|
99
|
+
chunk_size=chunk_size,
|
|
80
100
|
)
|
|
81
101
|
|
|
82
102
|
@staticmethod
|
|
83
103
|
def backward(ctx, *grad_output):
|
|
84
104
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
|
85
|
-
return *grads, None, None, None, None, None, None, None
|
|
105
|
+
return *grads, None, None, None, None, None, None, None, None
|
|
86
106
|
|
|
87
107
|
|
|
88
108
|
class LigerFusedLinearSimPOLoss(torch.nn.Module):
|
|
@@ -99,11 +119,18 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
|
|
|
99
119
|
compute_nll_loss: bool = True,
|
|
100
120
|
compiled: bool = True,
|
|
101
121
|
gamma: float = 0.5,
|
|
122
|
+
chunk_size: int = 1,
|
|
102
123
|
):
|
|
103
124
|
"""
|
|
104
125
|
Args:
|
|
105
126
|
ignore_index (int): Index to ignore in the loss.
|
|
106
127
|
beta (float): Weight for the odds ratio loss.
|
|
128
|
+
alpha (float): Weight for the alpha parameter.
|
|
129
|
+
label_smoothing (float): Label smoothing factor.
|
|
130
|
+
compute_nll_loss (bool): Whether to compute the NLL loss.
|
|
131
|
+
compiled (bool): Whether to use the torch compiled kernel.
|
|
132
|
+
gamma (float): Weight for the gamma parameter.
|
|
133
|
+
chunk_size (int): Size of chunks for processing.
|
|
107
134
|
"""
|
|
108
135
|
super().__init__()
|
|
109
136
|
self.ignore_index = ignore_index
|
|
@@ -113,8 +140,15 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
|
|
|
113
140
|
self.compute_nll_loss = compute_nll_loss
|
|
114
141
|
self.compiled = compiled
|
|
115
142
|
self.gamma = gamma
|
|
143
|
+
self.chunk_size = chunk_size
|
|
116
144
|
|
|
117
|
-
def forward(
|
|
145
|
+
def forward(
|
|
146
|
+
self,
|
|
147
|
+
lin_weight,
|
|
148
|
+
_input,
|
|
149
|
+
target,
|
|
150
|
+
bias=None,
|
|
151
|
+
):
|
|
118
152
|
return LigerFusedLinearSimPOFunction.apply(
|
|
119
153
|
_input,
|
|
120
154
|
lin_weight,
|
|
@@ -127,4 +161,5 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
|
|
|
127
161
|
self.compute_nll_loss,
|
|
128
162
|
self.compiled,
|
|
129
163
|
self.gamma,
|
|
164
|
+
self.chunk_size,
|
|
130
165
|
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{liger_kernel_nightly-0.5.4.dev20250305231637 → liger_kernel_nightly-0.5.4.dev20250307064336}/NOTICE
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|