liger-kernel-nightly 0.5.4.dev20250227064037__tar.gz → 0.5.4.dev20250305025024__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/.github/workflows/amd-ci.yml +4 -1
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/data/all_benchmark_data.csv +30 -31
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_distill_jsd_loss.py +2 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_kto_loss.py +4 -4
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/cpo_loss.py +9 -8
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/dpo_loss.py +4 -3
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +3 -3
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +10 -3
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/grpo_loss.py +4 -3
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/jsd_loss.py +24 -6
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/kto_loss.py +22 -12
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/orpo_loss.py +4 -3
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/simpo_loss.py +4 -3
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/chunked_loss/test_jsd_loss.py +49 -10
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/chunked_loss/test_kto_loss.py +85 -8
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/utils.py +7 -1
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/Makefile +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/setup.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250227064037 → liger_kernel_nightly-0.5.4.dev20250305025024}/test/triton/test_triton_monkey_patch.py +0 -0
|
@@ -47,6 +47,9 @@ jobs:
|
|
|
47
47
|
tests:
|
|
48
48
|
runs-on: linux-mi300-gpu-1
|
|
49
49
|
needs: [checkstyle]
|
|
50
|
+
strategy:
|
|
51
|
+
matrix:
|
|
52
|
+
rocm_version: ['6.2', '6.3']
|
|
50
53
|
|
|
51
54
|
steps:
|
|
52
55
|
- name: Checkout code
|
|
@@ -60,7 +63,7 @@ jobs:
|
|
|
60
63
|
- name: Setup Dependencies
|
|
61
64
|
run: |
|
|
62
65
|
python -m pip install --upgrade pip
|
|
63
|
-
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/
|
|
66
|
+
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm${{ matrix.rocm_version }}
|
|
64
67
|
|
|
65
68
|
- name: List Python Environments
|
|
66
69
|
run: python -m pip list
|
|
@@ -751,36 +751,6 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314
|
|
|
751
751
|
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
|
|
752
752
|
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
|
|
753
753
|
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
|
|
754
|
-
kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,7.841599941253662,7.801983833312988,7.849664211273193,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
|
|
755
|
-
kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,15.568096160888672,15.555737495422363,16.054176330566406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
|
|
756
|
-
kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,31.145376205444336,30.750951766967773,31.5398006439209,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
|
|
757
|
-
kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,61.49708938598633,61.49708938598633,61.49708938598633,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
|
|
758
|
-
kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,122.01449584960938,122.01449584960938,122.01449584960938,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
|
|
759
|
-
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,7.892335891723633,7.8687615394592285,8.03729248046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
|
|
760
|
-
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,14.16302490234375,13.813311576843262,15.860223770141602,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
|
|
761
|
-
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,25.56470489501953,25.564167022705078,25.641658782958984,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
|
|
762
|
-
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,53.0928955078125,53.0928955078125,53.0928955078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
|
|
763
|
-
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,108.76080322265625,108.76080322265625,108.76080322265625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
|
|
764
|
-
kto_loss,liger,full,speed,ms,B,Batch Size (B),2,8.662687301635742,8.488287925720215,9.611334800720215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
|
|
765
|
-
kto_loss,liger,full,speed,ms,B,Batch Size (B),4,18.40096092224121,17.99224281311035,18.57883644104004,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
|
|
766
|
-
kto_loss,liger,full,speed,ms,B,Batch Size (B),8,32.09159851074219,31.708070755004883,32.475128173828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
|
|
767
|
-
kto_loss,liger,full,speed,ms,B,Batch Size (B),16,69.30239868164062,69.30239868164062,69.30239868164062,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
|
|
768
|
-
kto_loss,liger,full,speed,ms,B,Batch Size (B),32,124.2437744140625,124.2437744140625,124.2437744140625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
|
|
769
|
-
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,11.449472427368164,11.407564163208008,11.773555755615234,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
|
|
770
|
-
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,20.871471405029297,20.862951278686523,20.879276275634766,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
|
|
771
|
-
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,41.16409683227539,40.760780334472656,41.567413330078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
|
|
772
|
-
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,77.720703125,77.720703125,77.720703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
|
|
773
|
-
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,156.25794982910156,156.25794982910156,156.25794982910156,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
|
|
774
|
-
kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2027.48583984375,2027.48583984375,2027.48583984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
|
|
775
|
-
kto_loss,liger,full,memory,MB,B,Batch Size (B),4,2789.736328125,2789.736328125,2789.736328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
|
|
776
|
-
kto_loss,liger,full,memory,MB,B,Batch Size (B),8,2801.751953125,2801.751953125,2801.751953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
|
|
777
|
-
kto_loss,liger,full,memory,MB,B,Batch Size (B),16,2825.783203125,2825.783203125,2825.783203125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
|
|
778
|
-
kto_loss,liger,full,memory,MB,B,Batch Size (B),32,2873.845703125,2873.845703125,2873.845703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
|
|
779
|
-
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,3786.7373046875,3786.7373046875,3786.7373046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
|
|
780
|
-
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.25390625,5544.25390625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
|
|
781
|
-
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
|
|
782
|
-
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
|
|
783
|
-
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
|
|
784
754
|
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
|
|
785
755
|
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
|
|
786
756
|
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
|
|
@@ -805,4 +775,33 @@ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,
|
|
|
805
775
|
distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
|
|
806
776
|
distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
|
|
807
777
|
distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
|
|
808
|
-
|
|
778
|
+
kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,3.9951679706573486,3.991487979888916,4.002252578735352,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
|
|
779
|
+
kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,7.8037919998168945,7.788575649261475,7.808595180511475,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
|
|
780
|
+
kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,15.43172836303711,15.430015563964844,15.4335355758667,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
|
|
781
|
+
kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,30.66864013671875,30.66431999206543,30.670501708984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
|
|
782
|
+
kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,61.1163215637207,61.1163215637207,61.1163215637207,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
|
|
783
|
+
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,3.8766400814056396,3.8680384159088135,3.8897151947021484,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
|
|
784
|
+
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,7.213727951049805,7.206470489501953,7.229574680328369,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
|
|
785
|
+
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,13.828800201416016,13.810944557189941,13.834943771362305,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
|
|
786
|
+
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,27.0930233001709,27.08517074584961,27.09713363647461,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
|
|
787
|
+
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,54.13715362548828,54.13715362548828,54.13715362548828,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
|
|
788
|
+
kto_loss,liger,full,speed,ms,B,Batch Size (B),2,4.782928466796875,4.677459239959717,5.3430914878845215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
|
|
789
|
+
kto_loss,liger,full,speed,ms,B,Batch Size (B),4,8.517248153686523,8.481344223022461,8.561504364013672,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
|
|
790
|
+
kto_loss,liger,full,speed,ms,B,Batch Size (B),8,16.547504425048828,16.513471603393555,16.678144454956055,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
|
|
791
|
+
kto_loss,liger,full,speed,ms,B,Batch Size (B),16,31.891263961791992,31.819705963134766,32.274131774902344,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
|
|
792
|
+
kto_loss,liger,full,speed,ms,B,Batch Size (B),32,62.953758239746094,62.953758239746094,62.953758239746094,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
|
|
793
|
+
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,6.201632022857666,6.163315296173096,6.314668655395508,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
|
|
794
|
+
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,11.156224250793457,11.142304420471191,11.207296371459961,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
|
|
795
|
+
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,21.249855041503906,21.231891632080078,21.264543533325195,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
|
|
796
|
+
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,41.55686569213867,41.536956787109375,41.57677459716797,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
|
|
797
|
+
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,81.56924438476562,81.56924438476562,81.56924438476562,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
|
|
798
|
+
kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2585.73876953125,2585.73876953125,2585.73876953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
|
|
799
|
+
kto_loss,liger,full,memory,MB,B,Batch Size (B),4,3348.9892578125,3348.9892578125,3348.9892578125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
|
|
800
|
+
kto_loss,liger,full,memory,MB,B,Batch Size (B),8,3361.0048828125,3361.0048828125,3361.0048828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
|
|
801
|
+
kto_loss,liger,full,memory,MB,B,Batch Size (B),16,3385.0361328125,3385.0361328125,3385.0361328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
|
|
802
|
+
kto_loss,liger,full,memory,MB,B,Batch Size (B),32,3433.0986328125,3433.0986328125,3433.0986328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
|
|
803
|
+
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,4341.74951171875,4341.74951171875,4341.74951171875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
804
|
+
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.26513671875,6099.26513671875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
805
|
+
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
806
|
+
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
807
|
+
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
@@ -149,7 +149,7 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
|
|
|
149
149
|
y=target,
|
|
150
150
|
preference_labels=preference_labels,
|
|
151
151
|
kl=kl,
|
|
152
|
-
)
|
|
152
|
+
)[0]
|
|
153
153
|
elif provider == "huggingface":
|
|
154
154
|
return torch_kto_loss(
|
|
155
155
|
x=_input,
|
|
@@ -157,7 +157,7 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
|
|
|
157
157
|
y=target,
|
|
158
158
|
preference_labels=preference_labels,
|
|
159
159
|
kl=kl,
|
|
160
|
-
)
|
|
160
|
+
)[0]
|
|
161
161
|
|
|
162
162
|
def full():
|
|
163
163
|
y = fwd()
|
|
@@ -230,7 +230,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
|
|
|
230
230
|
y=target,
|
|
231
231
|
preference_labels=preference_labels,
|
|
232
232
|
kl=kl,
|
|
233
|
-
)
|
|
233
|
+
)[0]
|
|
234
234
|
elif provider == "huggingface":
|
|
235
235
|
return torch_kto_loss(
|
|
236
236
|
x=_input,
|
|
@@ -238,7 +238,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
|
|
|
238
238
|
y=target,
|
|
239
239
|
preference_labels=preference_labels,
|
|
240
240
|
kl=kl,
|
|
241
|
-
)
|
|
241
|
+
)[0]
|
|
242
242
|
|
|
243
243
|
if mode == "forward":
|
|
244
244
|
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel_nightly"
|
|
7
|
-
version = "0.5.4.
|
|
7
|
+
version = "0.5.4.dev20250305025024"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -39,8 +39,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
39
39
|
|
|
40
40
|
return loss, chosen_rewards, rejected_rewards
|
|
41
41
|
|
|
42
|
-
@
|
|
42
|
+
@classmethod
|
|
43
43
|
def forward(
|
|
44
|
+
cls,
|
|
44
45
|
ctx,
|
|
45
46
|
_input,
|
|
46
47
|
weight,
|
|
@@ -53,13 +54,13 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
53
54
|
compute_nll_loss=True,
|
|
54
55
|
compiled=True,
|
|
55
56
|
):
|
|
56
|
-
return
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
57
|
+
return super().forward(
|
|
58
|
+
cls=cls,
|
|
59
|
+
ctx=ctx,
|
|
60
|
+
_input=_input,
|
|
61
|
+
weight=weight,
|
|
62
|
+
target=target,
|
|
63
|
+
bias=bias,
|
|
63
64
|
ignore_index=ignore_index,
|
|
64
65
|
alpha=alpha,
|
|
65
66
|
beta=beta,
|
|
@@ -52,8 +52,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
52
52
|
loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
|
|
53
53
|
return loss, chosen_rewards, rejected_rewards
|
|
54
54
|
|
|
55
|
-
@
|
|
55
|
+
@classmethod
|
|
56
56
|
def forward(
|
|
57
|
+
cls,
|
|
57
58
|
ctx,
|
|
58
59
|
_input,
|
|
59
60
|
weight,
|
|
@@ -68,13 +69,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
68
69
|
compiled=True,
|
|
69
70
|
use_ref_model=True,
|
|
70
71
|
):
|
|
71
|
-
return
|
|
72
|
+
return super().forward(
|
|
73
|
+
cls=cls,
|
|
72
74
|
ctx=ctx,
|
|
73
75
|
_input=_input,
|
|
74
76
|
weight=weight,
|
|
75
77
|
target=target,
|
|
76
78
|
bias=bias,
|
|
77
|
-
loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
|
|
78
79
|
ignore_index=ignore_index,
|
|
79
80
|
beta=beta,
|
|
80
81
|
compute_nll_loss=compute_nll_loss,
|
|
@@ -125,6 +125,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
125
125
|
|
|
126
126
|
@staticmethod
|
|
127
127
|
def forward(
|
|
128
|
+
cls,
|
|
128
129
|
ctx,
|
|
129
130
|
student_input,
|
|
130
131
|
student_weight,
|
|
@@ -133,7 +134,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
133
134
|
target,
|
|
134
135
|
student_bias=None,
|
|
135
136
|
teacher_bias=None,
|
|
136
|
-
loss_fn=None,
|
|
137
137
|
chunk_size=1024,
|
|
138
138
|
ignore_index=-100,
|
|
139
139
|
weight_hard_loss=0.5,
|
|
@@ -175,7 +175,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
175
175
|
|
|
176
176
|
loss_func_to_call = partial(
|
|
177
177
|
LigerFusedLinearDistillationBase._compute_loss,
|
|
178
|
-
distillation_loss_fn=
|
|
178
|
+
distillation_loss_fn=cls.distillation_loss_fn,
|
|
179
179
|
full_target=target,
|
|
180
180
|
ignore_index=ignore_index,
|
|
181
181
|
weight_hard_loss=weight_hard_loss,
|
|
@@ -263,4 +263,4 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
263
263
|
grad_weight = grad_weight * grad_output
|
|
264
264
|
grad_bias = grad_bias * grad_output if grad_bias is not None else None
|
|
265
265
|
|
|
266
|
-
return grad_input, grad_weight, None, grad_bias
|
|
266
|
+
return grad_input, grad_weight, None, None, None, grad_bias
|
|
@@ -16,12 +16,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
|
16
16
|
|
|
17
17
|
@staticmethod
|
|
18
18
|
def forward(
|
|
19
|
+
cls,
|
|
19
20
|
ctx,
|
|
20
21
|
_input,
|
|
21
22
|
weight,
|
|
22
23
|
target,
|
|
23
24
|
bias=None,
|
|
24
|
-
loss_fn=None,
|
|
25
25
|
chunk_size=1,
|
|
26
26
|
ignore_index=-100,
|
|
27
27
|
alpha=1.0,
|
|
@@ -89,7 +89,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
|
89
89
|
|
|
90
90
|
compute_loss = partial(
|
|
91
91
|
LigerFusedLinearPreferenceBase._compute_loss,
|
|
92
|
-
preference_loss_fn=
|
|
92
|
+
preference_loss_fn=cls.preference_loss_fn,
|
|
93
93
|
ignore_index=ignore_index,
|
|
94
94
|
alpha=alpha,
|
|
95
95
|
beta=beta,
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
1
2
|
from functools import partial
|
|
2
3
|
|
|
3
4
|
import torch
|
|
@@ -5,15 +6,22 @@ import torch.nn.functional as F
|
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class LigerFusedLinearRLHFBase(torch.autograd.Function):
|
|
9
|
+
@abstractmethod
|
|
10
|
+
def rlhf_loss_fn(*args, **kwargs):
|
|
11
|
+
"""
|
|
12
|
+
To be extended by subclasses.
|
|
13
|
+
"""
|
|
14
|
+
raise NotImplementedError("RLHF loss function must be implemented.")
|
|
15
|
+
|
|
8
16
|
@staticmethod
|
|
9
17
|
def forward(
|
|
18
|
+
cls,
|
|
10
19
|
ctx,
|
|
11
20
|
_input,
|
|
12
21
|
weight,
|
|
13
22
|
attention_mask,
|
|
14
23
|
rewards,
|
|
15
24
|
bias=None,
|
|
16
|
-
loss_fn=None,
|
|
17
25
|
num_generations=4,
|
|
18
26
|
beta=0.1,
|
|
19
27
|
compiled=True,
|
|
@@ -41,7 +49,7 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
|
|
|
41
49
|
use_ref_model=use_ref_model,
|
|
42
50
|
ref_weight=ref_weight,
|
|
43
51
|
ref_bias=ref_bias,
|
|
44
|
-
rlhf_loss_fn=
|
|
52
|
+
rlhf_loss_fn=cls.rlhf_loss_fn,
|
|
45
53
|
)
|
|
46
54
|
|
|
47
55
|
def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
|
|
@@ -202,7 +210,6 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
|
|
|
202
210
|
None, # grad_attention_mask
|
|
203
211
|
None, # grad_rewards
|
|
204
212
|
grad_bias,
|
|
205
|
-
None, # grad_loss_fn
|
|
206
213
|
None, # grad_chunk_size
|
|
207
214
|
None, # grad_beta
|
|
208
215
|
None, # grad_compiled
|
|
@@ -16,13 +16,13 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
16
16
|
|
|
17
17
|
@staticmethod
|
|
18
18
|
def forward(
|
|
19
|
+
cls,
|
|
19
20
|
ctx,
|
|
20
21
|
_input,
|
|
21
22
|
weight,
|
|
22
23
|
target,
|
|
23
24
|
preference_labels,
|
|
24
25
|
bias=None,
|
|
25
|
-
loss_fn=None,
|
|
26
26
|
chunk_size=1,
|
|
27
27
|
ignore_index=-100,
|
|
28
28
|
compiled=True,
|
|
@@ -30,6 +30,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
30
30
|
ref_input=None,
|
|
31
31
|
ref_weight=None,
|
|
32
32
|
ref_bias=None,
|
|
33
|
+
average_log_prob=False,
|
|
33
34
|
**loss_kwargs,
|
|
34
35
|
):
|
|
35
36
|
"""
|
|
@@ -59,6 +60,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
59
60
|
Shape: (batch_size,).
|
|
60
61
|
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
|
61
62
|
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
|
63
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token.
|
|
62
64
|
loss_kwargs (dict): Other possible arguments that a loss function might need
|
|
63
65
|
"""
|
|
64
66
|
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
|
|
@@ -72,14 +74,22 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
72
74
|
# Loss to be accumulated
|
|
73
75
|
loss_acc = torch.zeros((), device=_input.device)
|
|
74
76
|
|
|
77
|
+
# Metrics to be recorded
|
|
78
|
+
chosen_logps_sum = torch.zeros((), device=_input.device)
|
|
79
|
+
rejected_logps_sum = torch.zeros((), device=_input.device)
|
|
80
|
+
chosen_logits_sum = torch.zeros((), device=_input.device)
|
|
81
|
+
rejected_logits_sum = torch.zeros((), device=_input.device)
|
|
82
|
+
aggregated_aux_outputs = []
|
|
83
|
+
|
|
75
84
|
compute_loss = partial(
|
|
76
85
|
LigerFusedLinearUnpairedPreferenceBase._compute_loss,
|
|
77
|
-
preference_loss_fn=
|
|
86
|
+
preference_loss_fn=cls.preference_loss_fn,
|
|
78
87
|
full_target=target,
|
|
79
88
|
ignore_index=ignore_index,
|
|
80
89
|
use_ref_model=use_ref_model,
|
|
81
90
|
ref_weight=ref_weight,
|
|
82
91
|
ref_bias=ref_bias,
|
|
92
|
+
average_log_prob=average_log_prob,
|
|
83
93
|
**loss_kwargs,
|
|
84
94
|
)
|
|
85
95
|
|
|
@@ -88,7 +98,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
88
98
|
Fused forward and backward pass for a chunk of input and target.
|
|
89
99
|
"""
|
|
90
100
|
argnums = (0, 1, 4) if bias is not None else (0, 1)
|
|
91
|
-
return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=
|
|
101
|
+
return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
|
|
92
102
|
input_chunk,
|
|
93
103
|
weight,
|
|
94
104
|
target_chunk,
|
|
@@ -103,9 +113,19 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
103
113
|
preference_labels_chunk=None,
|
|
104
114
|
ref_input_chunk=None,
|
|
105
115
|
):
|
|
106
|
-
(
|
|
107
|
-
|
|
108
|
-
|
|
116
|
+
(
|
|
117
|
+
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias),
|
|
118
|
+
(
|
|
119
|
+
chunk_loss,
|
|
120
|
+
(
|
|
121
|
+
chunk_chosen_logps_sum,
|
|
122
|
+
chunk_rejected_logps_sum,
|
|
123
|
+
chunk_chosen_logits_sum,
|
|
124
|
+
chunk_rejected_logits_sum,
|
|
125
|
+
*aux_outputs,
|
|
126
|
+
),
|
|
127
|
+
),
|
|
128
|
+
) = fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
|
|
109
129
|
if bias is not None:
|
|
110
130
|
grad_bias.add_(chunk_grad_bias[0]) # accumulate bias gradient
|
|
111
131
|
|
|
@@ -116,6 +136,23 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
116
136
|
# Accumulate loss
|
|
117
137
|
loss_acc.add_(chunk_loss)
|
|
118
138
|
|
|
139
|
+
# Accumulate metrics
|
|
140
|
+
chosen_logps_sum.add_(chunk_chosen_logps_sum)
|
|
141
|
+
rejected_logps_sum.add_(chunk_rejected_logps_sum)
|
|
142
|
+
chosen_logits_sum.add_(chunk_chosen_logits_sum)
|
|
143
|
+
rejected_logits_sum.add_(chunk_rejected_logits_sum)
|
|
144
|
+
|
|
145
|
+
# aux_outputs
|
|
146
|
+
# Initialize storage for aux_outputs
|
|
147
|
+
if len(aggregated_aux_outputs) == 0:
|
|
148
|
+
for aux in aux_outputs:
|
|
149
|
+
aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
|
|
150
|
+
|
|
151
|
+
# Process each aux_output
|
|
152
|
+
for i, aux in enumerate(aux_outputs):
|
|
153
|
+
if aux.ndim == 0:
|
|
154
|
+
aggregated_aux_outputs[i].add_(aux)
|
|
155
|
+
|
|
119
156
|
if compiled:
|
|
120
157
|
fused_fwd_bwd = torch.compile(fused_fwd_bwd)
|
|
121
158
|
|
|
@@ -151,12 +188,25 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
151
188
|
# accumulate loss, gradients, and metrics
|
|
152
189
|
accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
|
|
153
190
|
|
|
191
|
+
# Aggregate aux outputs lists into tensors
|
|
192
|
+
for i, aux in enumerate(aggregated_aux_outputs):
|
|
193
|
+
if isinstance(aux, list):
|
|
194
|
+
aggregated_aux_outputs[i] = torch.cat(aux, dim=0)
|
|
195
|
+
|
|
154
196
|
ctx.save_for_backward(
|
|
155
197
|
torch.cat(grad_inputs, dim=0),
|
|
156
198
|
grad_weight,
|
|
157
199
|
grad_bias,
|
|
158
200
|
)
|
|
159
|
-
|
|
201
|
+
|
|
202
|
+
return_vars = (
|
|
203
|
+
chosen_logps_sum,
|
|
204
|
+
rejected_logps_sum,
|
|
205
|
+
chosen_logits_sum,
|
|
206
|
+
rejected_logits_sum,
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
return loss_acc, (*return_vars, *aggregated_aux_outputs)
|
|
160
210
|
|
|
161
211
|
@staticmethod
|
|
162
212
|
def backward(ctx, *grad_output):
|
|
@@ -173,21 +223,37 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
173
223
|
input_chunk,
|
|
174
224
|
weight,
|
|
175
225
|
target_chunk,
|
|
226
|
+
preference_labels_chunk,
|
|
176
227
|
bias=None,
|
|
177
228
|
ignore_index=-100,
|
|
229
|
+
average_log_prob=False,
|
|
178
230
|
):
|
|
179
231
|
logits_chunk = input_chunk @ weight.t()
|
|
180
232
|
if bias is not None:
|
|
181
233
|
logits_chunk = logits_chunk + bias
|
|
182
234
|
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
|
|
183
|
-
|
|
184
235
|
loss_mask_chunk = target_chunk != ignore_index
|
|
185
236
|
label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
|
|
186
237
|
|
|
187
238
|
per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
239
|
+
if average_log_prob:
|
|
240
|
+
log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
|
|
241
|
+
else:
|
|
242
|
+
log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1)
|
|
243
|
+
|
|
244
|
+
chosen_logps_sum = (log_probs * preference_labels_chunk.unsqueeze(1)).sum()
|
|
245
|
+
rejected_logps_sum = (log_probs * (~preference_labels_chunk).unsqueeze(1)).sum()
|
|
246
|
+
|
|
247
|
+
chosen_logits_sum = (logits_chunk * preference_labels_chunk.unsqueeze(1)).sum()
|
|
248
|
+
rejected_logits_sum = (logits_chunk * (~preference_labels_chunk).unsqueeze(1)).sum()
|
|
249
|
+
|
|
250
|
+
return (
|
|
251
|
+
log_probs,
|
|
252
|
+
chosen_logps_sum,
|
|
253
|
+
rejected_logps_sum,
|
|
254
|
+
chosen_logits_sum,
|
|
255
|
+
rejected_logits_sum,
|
|
256
|
+
)
|
|
191
257
|
|
|
192
258
|
@staticmethod
|
|
193
259
|
def _compute_loss(
|
|
@@ -203,6 +269,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
203
269
|
ref_input_chunk=None,
|
|
204
270
|
ref_weight=None,
|
|
205
271
|
ref_bias=None,
|
|
272
|
+
average_log_prob=False,
|
|
206
273
|
**loss_kwargs,
|
|
207
274
|
):
|
|
208
275
|
"""
|
|
@@ -218,29 +285,57 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
|
|
|
218
285
|
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
|
219
286
|
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
|
220
287
|
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
|
288
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token.
|
|
221
289
|
loss_kwargs (dict): Additional arguments for the loss function.
|
|
222
290
|
"""
|
|
223
|
-
|
|
291
|
+
(
|
|
292
|
+
log_prob_chunk,
|
|
293
|
+
chosen_logps_sum,
|
|
294
|
+
rejected_logps_sum,
|
|
295
|
+
chosen_logits_sum,
|
|
296
|
+
rejected_logits_sum,
|
|
297
|
+
) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
|
|
224
298
|
input_chunk,
|
|
225
299
|
weight,
|
|
226
300
|
target_chunk,
|
|
301
|
+
preference_labels_chunk,
|
|
227
302
|
bias=bias,
|
|
228
303
|
ignore_index=ignore_index,
|
|
304
|
+
average_log_prob=average_log_prob,
|
|
229
305
|
)
|
|
230
306
|
|
|
231
307
|
if use_ref_model:
|
|
232
308
|
with torch.no_grad():
|
|
233
|
-
|
|
309
|
+
(
|
|
310
|
+
ref_log_prob_chunk,
|
|
311
|
+
_,
|
|
312
|
+
_,
|
|
313
|
+
_,
|
|
314
|
+
_,
|
|
315
|
+
) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
|
|
234
316
|
ref_input_chunk,
|
|
235
317
|
ref_weight,
|
|
236
318
|
target_chunk,
|
|
319
|
+
preference_labels_chunk,
|
|
237
320
|
ref_bias,
|
|
238
321
|
ignore_index=ignore_index,
|
|
322
|
+
average_log_prob=average_log_prob,
|
|
239
323
|
)
|
|
240
|
-
loss_kwargs["
|
|
324
|
+
loss_kwargs["ref_log_prob_chunk"] = ref_log_prob_chunk
|
|
241
325
|
|
|
242
|
-
|
|
243
|
-
|
|
326
|
+
preference_loss_outputs = preference_loss_fn(
|
|
327
|
+
log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
|
|
328
|
+
)
|
|
329
|
+
if isinstance(preference_loss_outputs, tuple):
|
|
330
|
+
preference_loss_chunk, *aux_outputs = preference_loss_outputs
|
|
331
|
+
else:
|
|
332
|
+
preference_loss_chunk, aux_outputs = preference_loss_outputs, []
|
|
333
|
+
|
|
334
|
+
return_vars = (
|
|
335
|
+
chosen_logps_sum,
|
|
336
|
+
rejected_logps_sum,
|
|
337
|
+
chosen_logits_sum,
|
|
338
|
+
rejected_logits_sum,
|
|
244
339
|
)
|
|
245
340
|
|
|
246
|
-
return preference_loss_chunk
|
|
341
|
+
return preference_loss_chunk, (*return_vars, *aux_outputs)
|
|
@@ -63,8 +63,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
|
|
|
63
63
|
|
|
64
64
|
return loss, metrics
|
|
65
65
|
|
|
66
|
-
@
|
|
66
|
+
@classmethod
|
|
67
67
|
def forward(
|
|
68
|
+
cls,
|
|
68
69
|
ctx,
|
|
69
70
|
_input,
|
|
70
71
|
weight,
|
|
@@ -79,12 +80,12 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
|
|
|
79
80
|
use_ref_model=True,
|
|
80
81
|
num_generations=1,
|
|
81
82
|
):
|
|
82
|
-
return
|
|
83
|
+
return super().forward(
|
|
84
|
+
cls=cls,
|
|
83
85
|
ctx=ctx,
|
|
84
86
|
_input=_input,
|
|
85
87
|
weight=weight,
|
|
86
88
|
attention_mask=attention_mask,
|
|
87
|
-
loss_fn=LigerFusedLinearGRPOFunction.rlhf_loss_fn,
|
|
88
89
|
rewards=rewards,
|
|
89
90
|
bias=bias,
|
|
90
91
|
ref_input=ref_input,
|