liger-kernel-nightly 0.5.3.dev20250220230230__tar.gz → 0.5.3.dev20250221003838__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/PKG-INFO +4 -2
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/README.md +3 -1
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/data/all_benchmark_data.csv +37 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_kto_loss.py +6 -6
- liger_kernel_nightly-0.5.3.dev20250221003838/benchmark/scripts/benchmark_tvd.py +136 -0
- liger_kernel_nightly-0.5.3.dev20250221003838/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/pyproject.toml +1 -1
- liger_kernel_nightly-0.5.3.dev20250221003838/src/liger_kernel/ops/tvd.py +208 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/__init__.py +1 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/functional.py +15 -1
- liger_kernel_nightly-0.5.3.dev20250221003838/src/liger_kernel/transformers/tvd.py +15 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel_nightly.egg-info/PKG-INFO +4 -2
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
- liger_kernel_nightly-0.5.3.dev20250221003838/test/transformers/test_tvd.py +195 -0
- liger_kernel_nightly-0.5.3.dev20250220230230/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/Makefile +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/setup.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220230230 → liger_kernel_nightly-0.5.3.dev20250221003838}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: liger_kernel_nightly
|
3
|
-
Version: 0.5.3.dev20250220230230
|
3
|
+
Version: 0.5.3.dev20250221003838
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
@@ -126,7 +126,7 @@ Requires-Dist: mkdocs-material; extra == "dev"
|
|
126
126
|
|
127
127
|
**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
|
128
128
|
|
129
|
-
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
129
|
+
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
130
130
|
|
131
131
|
## Supercharge Your Model with Liger Kernel
|
132
132
|
|
@@ -341,6 +341,7 @@ loss.backward()
|
|
341
341
|
| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
|
342
342
|
| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
|
343
343
|
| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
|
344
|
+
| Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
|
344
345
|
|
345
346
|
### Distillation Kernels
|
346
347
|
|
@@ -349,6 +350,7 @@ loss.backward()
|
|
349
350
|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
|
350
351
|
| JSD | `liger_kernel.transformers.LigerJSD` |
|
351
352
|
| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
|
353
|
+
| TVD | `liger_kernel.transformers.LigerTVDLoss` |
|
352
354
|
|
353
355
|
### Experimental Kernels
|
354
356
|
|
@@ -78,7 +78,7 @@
|
|
78
78
|
|
79
79
|
**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
|
80
80
|
|
81
|
-
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
81
|
+
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
82
82
|
|
83
83
|
## Supercharge Your Model with Liger Kernel
|
84
84
|
|
@@ -293,6 +293,7 @@ loss.backward()
|
|
293
293
|
| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
|
294
294
|
| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
|
295
295
|
| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
|
296
|
+
| Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
|
296
297
|
|
297
298
|
### Distillation Kernels
|
298
299
|
|
@@ -301,6 +302,7 @@ loss.backward()
|
|
301
302
|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
|
302
303
|
| JSD | `liger_kernel.transformers.LigerJSD` |
|
303
304
|
| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
|
305
|
+
| TVD | `liger_kernel.transformers.LigerTVDLoss` |
|
304
306
|
|
305
307
|
### Experimental Kernels
|
306
308
|
|
@@ -505,6 +505,42 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859
|
|
505
505
|
fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
|
506
506
|
fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
|
507
507
|
fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
|
508
|
+
tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
509
|
+
tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
510
|
+
tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
511
|
+
tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
512
|
+
tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
513
|
+
tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
514
|
+
tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
515
|
+
tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
516
|
+
tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
517
|
+
tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
518
|
+
tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
519
|
+
tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
520
|
+
tvd,liger,forward,speed,ms,V,vocab size,4096,0.47814399003982544,0.4774720072746277,0.4790079891681671,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
521
|
+
tvd,liger,forward,speed,ms,V,vocab size,8192,0.906495988368988,0.905951976776123,0.9073920249938965,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
522
|
+
tvd,liger,forward,speed,ms,V,vocab size,16384,1.8787360191345215,1.8778239488601685,1.8797119855880737,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
523
|
+
tvd,liger,forward,speed,ms,V,vocab size,32768,3.5788800716400146,3.5772159099578857,3.58076810836792,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
524
|
+
tvd,liger,forward,speed,ms,V,vocab size,65536,7.008831977844238,7.007718086242676,7.010636806488037,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
525
|
+
tvd,liger,forward,speed,ms,V,vocab size,131072,13.88646411895752,13.88128662109375,13.890560150146484,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
526
|
+
tvd,torch,forward,speed,ms,V,vocab size,4096,1.308608055114746,1.306502342224121,1.3104127645492554,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
527
|
+
tvd,torch,forward,speed,ms,V,vocab size,8192,2.4735519886016846,2.472287893295288,2.4749441146850586,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
528
|
+
tvd,torch,forward,speed,ms,V,vocab size,16384,4.828320026397705,4.826848030090332,4.830643177032471,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
529
|
+
tvd,torch,forward,speed,ms,V,vocab size,32768,9.5206880569458,9.517024040222168,9.525145530700684,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
530
|
+
tvd,torch,forward,speed,ms,V,vocab size,65536,19.01535987854004,19.011123657226562,19.01806640625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
531
|
+
tvd,torch,forward,speed,ms,V,vocab size,131072,38.022865295410156,38.01945877075195,38.02627182006836,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
532
|
+
tvd,liger,full,speed,ms,V,vocab size,4096,2.626512050628662,2.621260643005371,2.646751880645752,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
533
|
+
tvd,liger,full,speed,ms,V,vocab size,8192,4.661711692810059,4.657618999481201,4.662930965423584,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
534
|
+
tvd,liger,full,speed,ms,V,vocab size,16384,9.088272094726562,9.080741882324219,9.092268943786621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
535
|
+
tvd,liger,full,speed,ms,V,vocab size,32768,18.116064071655273,18.112728118896484,18.118234634399414,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
536
|
+
tvd,liger,full,speed,ms,V,vocab size,65536,35.85124969482422,35.849971771240234,35.85252380371094,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
537
|
+
tvd,liger,full,speed,ms,V,vocab size,131072,71.1648941040039,71.1648941040039,71.1648941040039,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
538
|
+
tvd,torch,full,speed,ms,V,vocab size,4096,4.361599922180176,4.360159873962402,4.3639678955078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
539
|
+
tvd,torch,full,speed,ms,V,vocab size,8192,8.11302375793457,8.11075210571289,8.114463806152344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
540
|
+
tvd,torch,full,speed,ms,V,vocab size,16384,15.841055870056152,15.837087631225586,15.841856002807617,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
541
|
+
tvd,torch,full,speed,ms,V,vocab size,32768,31.71219253540039,31.706951141357422,31.715898513793945,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
542
|
+
tvd,torch,full,speed,ms,V,vocab size,65536,63.17919921875,63.17919921875,63.17919921875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
543
|
+
tvd,torch,full,speed,ms,V,vocab size,131072,126.0436782836914,126.0436782836914,126.0436782836914,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
508
544
|
group_norm,liger,forward,speed,ms,C,num_channels,32,0.03481600061058998,0.03379200026392937,0.03993599861860275,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
|
509
545
|
group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05119999870657921,0.05222399905323982,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
|
510
546
|
group_norm,liger,forward,speed,ms,C,num_channels,128,0.08499199897050858,0.08396799862384796,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
|
@@ -769,3 +805,4 @@ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,
|
|
769
805
|
distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
|
770
806
|
distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
|
771
807
|
distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
|
808
|
+
|
@@ -103,8 +103,8 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
|
|
103
103
|
H=H,
|
104
104
|
V=V,
|
105
105
|
dtype=dtype,
|
106
|
-
|
107
|
-
|
106
|
+
use_bias=bias,
|
107
|
+
use_ref_bias=bias,
|
108
108
|
ignore_index=ignore_index,
|
109
109
|
beta=beta,
|
110
110
|
).to(device)
|
@@ -113,8 +113,8 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
|
|
113
113
|
H=H,
|
114
114
|
V=V,
|
115
115
|
dtype=dtype,
|
116
|
-
|
117
|
-
|
116
|
+
use_bias=bias,
|
117
|
+
use_ref_bias=bias,
|
118
118
|
ignore_index=ignore_index,
|
119
119
|
beta=beta,
|
120
120
|
).to(device)
|
@@ -189,7 +189,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
|
|
189
189
|
dtype=dtype,
|
190
190
|
beta=beta,
|
191
191
|
ignore_index=ignore_index,
|
192
|
-
|
192
|
+
use_bias=bias,
|
193
193
|
).to(device)
|
194
194
|
liger_kto_loss = LigerLMHeadKTO(
|
195
195
|
H=H,
|
@@ -197,7 +197,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
|
|
197
197
|
dtype=dtype,
|
198
198
|
beta=beta,
|
199
199
|
ignore_index=ignore_index,
|
200
|
-
|
200
|
+
use_bias=bias,
|
201
201
|
).to(device)
|
202
202
|
|
203
203
|
# Input shape: [B, T, H]
|
@@ -0,0 +1,136 @@
|
|
1
|
+
import torch
|
2
|
+
import triton
|
3
|
+
from utils import (
|
4
|
+
QUANTILES,
|
5
|
+
SingleBenchmarkRunInput,
|
6
|
+
SingleBenchmarkRunOutput,
|
7
|
+
_test_memory,
|
8
|
+
parse_benchmark_script_args,
|
9
|
+
run_benchmarks,
|
10
|
+
)
|
11
|
+
|
12
|
+
from liger_kernel.transformers.tvd import LigerTVDLoss
|
13
|
+
|
14
|
+
|
15
|
+
class TorchTVDLoss(torch.nn.Module):
|
16
|
+
def __init__(self, reduction="batchmean"):
|
17
|
+
super(TorchTVDLoss, self).__init__()
|
18
|
+
self.reduction = reduction
|
19
|
+
|
20
|
+
def forward(self, p, q):
|
21
|
+
tvd = torch.abs(p - q) / 2.0
|
22
|
+
if self.reduction == "mean":
|
23
|
+
return torch.sum(tvd) / (p.size(0) * p.size(1))
|
24
|
+
elif self.reduction == "sum":
|
25
|
+
return torch.sum(tvd)
|
26
|
+
elif self.reduction == "none":
|
27
|
+
return tvd
|
28
|
+
elif self.reduction == "batchmean":
|
29
|
+
return torch.sum(tvd) / p.size(0)
|
30
|
+
else:
|
31
|
+
raise ValueError("Invalid reduction type.")
|
32
|
+
|
33
|
+
|
34
|
+
S, E = 12, 18
|
35
|
+
|
36
|
+
|
37
|
+
def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
|
38
|
+
reduction = "batchmean"
|
39
|
+
V = input.x
|
40
|
+
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
|
41
|
+
torch_tvd = TorchTVDLoss(reduction=reduction)
|
42
|
+
liger_tvd = LigerTVDLoss(reduction=reduction)
|
43
|
+
|
44
|
+
_input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
|
45
|
+
target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
|
46
|
+
|
47
|
+
def fwd():
|
48
|
+
if input.kernel_provider == "liger":
|
49
|
+
return liger_tvd(_input, target)
|
50
|
+
else:
|
51
|
+
return torch_tvd(_input, target)
|
52
|
+
|
53
|
+
if input.kernel_operation_mode == "forward":
|
54
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
|
55
|
+
elif input.kernel_operation_mode == "backward":
|
56
|
+
y = fwd()
|
57
|
+
|
58
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
59
|
+
lambda: y.backward(retain_graph=True),
|
60
|
+
quantiles=QUANTILES,
|
61
|
+
grad_to_none=[_input],
|
62
|
+
rep=100,
|
63
|
+
)
|
64
|
+
elif input.kernel_operation_mode == "full":
|
65
|
+
|
66
|
+
def full():
|
67
|
+
y = fwd()
|
68
|
+
y.backward(retain_graph=True)
|
69
|
+
|
70
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
71
|
+
full, quantiles=QUANTILES, rep=100
|
72
|
+
)
|
73
|
+
return SingleBenchmarkRunOutput(
|
74
|
+
y_20=ms_20,
|
75
|
+
y_50=ms_50,
|
76
|
+
y_80=ms_80,
|
77
|
+
)
|
78
|
+
|
79
|
+
|
80
|
+
def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
|
81
|
+
reduction = "batchmean"
|
82
|
+
torch_tvd = TorchTVDLoss(reduction=reduction)
|
83
|
+
liger_tvd = LigerTVDLoss(reduction=reduction)
|
84
|
+
|
85
|
+
V = input.x
|
86
|
+
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
|
87
|
+
|
88
|
+
_input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
|
89
|
+
target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
|
90
|
+
|
91
|
+
def fwd():
|
92
|
+
if input.kernel_provider == "liger":
|
93
|
+
return liger_tvd(_input, target)
|
94
|
+
else:
|
95
|
+
return torch_tvd(_input, target)
|
96
|
+
|
97
|
+
def full():
|
98
|
+
y = fwd()
|
99
|
+
y.backward(retain_graph=True)
|
100
|
+
|
101
|
+
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
|
102
|
+
|
103
|
+
return SingleBenchmarkRunOutput(
|
104
|
+
y_20=mem_20,
|
105
|
+
y_50=mem_50,
|
106
|
+
y_80=mem_80,
|
107
|
+
)
|
108
|
+
|
109
|
+
|
110
|
+
if __name__ == "__main__":
|
111
|
+
args = parse_benchmark_script_args()
|
112
|
+
common_args = {
|
113
|
+
"kernel_name": "tvd",
|
114
|
+
"x_name": "V",
|
115
|
+
"x_label": "vocab size",
|
116
|
+
"x_values": [2**i for i in range(12, 18)],
|
117
|
+
"kernel_providers": ["liger", "torch"],
|
118
|
+
"extra_benchmark_configs": [{"B": 8, "T": 2048}],
|
119
|
+
"overwrite": args.overwrite,
|
120
|
+
}
|
121
|
+
|
122
|
+
run_benchmarks(
|
123
|
+
bench_test_fn=bench_memory_tvd,
|
124
|
+
kernel_operation_modes=["full"],
|
125
|
+
metric_name="memory",
|
126
|
+
metric_unit="MB",
|
127
|
+
**common_args,
|
128
|
+
)
|
129
|
+
|
130
|
+
run_benchmarks(
|
131
|
+
bench_test_fn=bench_speed_tvd,
|
132
|
+
kernel_operation_modes=["forward", "full"],
|
133
|
+
metric_name="speed",
|
134
|
+
metric_unit="ms",
|
135
|
+
**common_args,
|
136
|
+
)
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.5.3.
|
7
|
+
version = "0.5.3.dev20250221003838"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -0,0 +1,208 @@
|
|
1
|
+
from typing import Literal, Optional
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import triton
|
5
|
+
import triton.language as tl
|
6
|
+
|
7
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
8
|
+
|
9
|
+
MAX_FUSED_SIZE = 65536 // 4
|
10
|
+
|
11
|
+
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
12
|
+
|
13
|
+
_REDUCTION_MODE_NONE = tl.constexpr(0)
|
14
|
+
_REDUCTION_MODE_SUM = tl.constexpr(1)
|
15
|
+
_REDUCTION_MODE_MEAN = tl.constexpr(2)
|
16
|
+
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
|
17
|
+
|
18
|
+
_str_to_reduction_mode = {
|
19
|
+
"none": _REDUCTION_MODE_NONE.value,
|
20
|
+
"sum": _REDUCTION_MODE_SUM.value,
|
21
|
+
"mean": _REDUCTION_MODE_MEAN.value,
|
22
|
+
"batchmean": _REDUCTION_MODE_BATCHMEAN.value,
|
23
|
+
}
|
24
|
+
|
25
|
+
|
26
|
+
def get_num_warps(BLOCK_SIZE):
|
27
|
+
num_warps = 4
|
28
|
+
if BLOCK_SIZE >= 32768:
|
29
|
+
num_warps = 32
|
30
|
+
elif BLOCK_SIZE >= 8192:
|
31
|
+
num_warps = 16
|
32
|
+
elif BLOCK_SIZE >= 2048:
|
33
|
+
num_warps = 8
|
34
|
+
|
35
|
+
return num_warps
|
36
|
+
|
37
|
+
|
38
|
+
@triton.jit
|
39
|
+
def _tv_distance_kernel(
|
40
|
+
p_ptr,
|
41
|
+
p_stride,
|
42
|
+
q_ptr,
|
43
|
+
q_stride,
|
44
|
+
loss_ptr,
|
45
|
+
loss_stride,
|
46
|
+
grads_ptr,
|
47
|
+
grads_stride,
|
48
|
+
label_ptr,
|
49
|
+
ignore_index: tl.constexpr,
|
50
|
+
n_cols,
|
51
|
+
BLOCK_SIZE: tl.constexpr,
|
52
|
+
HAS_LABEL: tl.constexpr,
|
53
|
+
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
|
54
|
+
):
|
55
|
+
pid = tl.program_id(0).to(tl.int64)
|
56
|
+
p_ptr += pid * p_stride
|
57
|
+
q_ptr += pid * q_stride
|
58
|
+
loss_ptr += pid * loss_stride
|
59
|
+
grads_ptr += pid * grads_stride
|
60
|
+
label_ptr += pid
|
61
|
+
|
62
|
+
base_offsets = tl.arange(0, BLOCK_SIZE)
|
63
|
+
|
64
|
+
if HAS_LABEL:
|
65
|
+
label = tl.load(label_ptr)
|
66
|
+
if label == ignore_index:
|
67
|
+
for i in range(0, n_cols, BLOCK_SIZE):
|
68
|
+
offsets = i + base_offsets
|
69
|
+
mask = offsets < n_cols
|
70
|
+
tl.store(grads_ptr + offsets, 0.0, mask=mask)
|
71
|
+
if reduction == _REDUCTION_MODE_NONE:
|
72
|
+
tl.store(loss_ptr + offsets, 0.0, mask=mask)
|
73
|
+
return
|
74
|
+
|
75
|
+
loss_sum = 0.0
|
76
|
+
for i in range(0, n_cols, BLOCK_SIZE):
|
77
|
+
offsets = i + base_offsets
|
78
|
+
mask = offsets < n_cols
|
79
|
+
|
80
|
+
p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
|
81
|
+
q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
|
82
|
+
|
83
|
+
# TVD(P || Q) = 0.5 * |P - Q|
|
84
|
+
tv_loss = 0.5 * tl.abs(p - q)
|
85
|
+
|
86
|
+
grad_res = tl.where(p > q, 0.5, -0.5)
|
87
|
+
|
88
|
+
tl.store(grads_ptr + offsets, grad_res, mask=mask)
|
89
|
+
|
90
|
+
if reduction == _REDUCTION_MODE_NONE:
|
91
|
+
tl.store(loss_ptr + offsets, tv_loss, mask=mask)
|
92
|
+
else:
|
93
|
+
loss_sum += tl.sum(tv_loss, axis=0)
|
94
|
+
|
95
|
+
if reduction != _REDUCTION_MODE_NONE:
|
96
|
+
tl.store(loss_ptr, loss_sum)
|
97
|
+
|
98
|
+
|
99
|
+
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
|
100
|
+
BT, V = p.shape
|
101
|
+
|
102
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
103
|
+
num_warps = get_num_warps(BLOCK_SIZE)
|
104
|
+
|
105
|
+
grid = (BT,)
|
106
|
+
|
107
|
+
reduction = _str_to_reduction_mode[reduction]
|
108
|
+
|
109
|
+
out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
|
110
|
+
output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
|
111
|
+
grads = torch.empty_like(p)
|
112
|
+
|
113
|
+
n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
|
114
|
+
|
115
|
+
_tv_distance_kernel[grid](
|
116
|
+
p,
|
117
|
+
p.stride(0),
|
118
|
+
q,
|
119
|
+
q.stride(0),
|
120
|
+
output_tensor,
|
121
|
+
output_tensor.stride(0),
|
122
|
+
grads,
|
123
|
+
grads.stride(0),
|
124
|
+
shift_labels if has_label else torch.empty(1, device=p.device),
|
125
|
+
ignore_index,
|
126
|
+
V,
|
127
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
128
|
+
HAS_LABEL=has_label,
|
129
|
+
num_warps=num_warps,
|
130
|
+
reduction=reduction,
|
131
|
+
)
|
132
|
+
|
133
|
+
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
|
134
|
+
return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
|
135
|
+
elif reduction == _REDUCTION_MODE_SUM.value:
|
136
|
+
return output_tensor.sum(dim=0), grads
|
137
|
+
elif reduction == _REDUCTION_MODE_MEAN.value:
|
138
|
+
return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
|
139
|
+
else:
|
140
|
+
return output_tensor, grads
|
141
|
+
|
142
|
+
|
143
|
+
def tvd_backward_triton(grad_output, grads):
|
144
|
+
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
|
145
|
+
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
146
|
+
return grads
|
147
|
+
|
148
|
+
return grads * grad_output
|
149
|
+
|
150
|
+
|
151
|
+
class LigerTVDLossFunction(torch.autograd.Function):
|
152
|
+
"""
|
153
|
+
Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
|
154
|
+
"""
|
155
|
+
|
156
|
+
@staticmethod
|
157
|
+
@ensure_contiguous
|
158
|
+
def forward(
|
159
|
+
ctx,
|
160
|
+
p: torch.Tensor,
|
161
|
+
q: torch.Tensor,
|
162
|
+
shift_labels: Optional[torch.Tensor] = None,
|
163
|
+
reduction: REDUCTION_LITERAL = "batchmean",
|
164
|
+
ignore_index: int = -100,
|
165
|
+
) -> torch.Tensor:
|
166
|
+
"""A forward pass for the Total Variation Distance Loss.
|
167
|
+
|
168
|
+
Args:
|
169
|
+
ctx: Torch autograd context
|
170
|
+
p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
|
171
|
+
q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
|
172
|
+
shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
|
173
|
+
reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
|
174
|
+
ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
|
175
|
+
|
176
|
+
Returns:
|
177
|
+
torch.Tensor: The computed Total Variation Distance Loss.
|
178
|
+
"""
|
179
|
+
has_label = False
|
180
|
+
if shift_labels is not None:
|
181
|
+
assert shift_labels.shape == (
|
182
|
+
p.shape[0],
|
183
|
+
), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
|
184
|
+
shift_labels = shift_labels.contiguous()
|
185
|
+
has_label = True
|
186
|
+
|
187
|
+
loss, grads = tv_distance_forward_triton(
|
188
|
+
p, q, shift_labels, reduction, ignore_index, has_label
|
189
|
+
)
|
190
|
+
ctx.save_for_backward(grads)
|
191
|
+
return loss
|
192
|
+
|
193
|
+
@staticmethod
|
194
|
+
@ensure_contiguous
|
195
|
+
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
|
196
|
+
"""A backward pass for the Total Variation Distance Loss.
|
197
|
+
|
198
|
+
Args:
|
199
|
+
ctx: Torch autograd context
|
200
|
+
grad_output (torch.Tensor): The gradient of the loss with respect to the output.
|
201
|
+
|
202
|
+
Returns:
|
203
|
+
tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
|
204
|
+
"""
|
205
|
+
(grads,) = ctx.saved_tensors
|
206
|
+
grads = tvd_backward_triton(grad_output, grads)
|
207
|
+
|
208
|
+
return grads, None, None, None, None
|
@@ -18,6 +18,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2
|
|
18
18
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
|
19
19
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
|
20
20
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
|
21
|
+
from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
|
21
22
|
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
|
22
23
|
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
|
23
24
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
|
@@ -12,7 +12,7 @@ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
|
12
12
|
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
|
13
13
|
from liger_kernel.ops.rope import LigerRopeFunction
|
14
14
|
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
|
15
|
-
|
15
|
+
from liger_kernel.ops.tvd import LigerTVDLossFunction
|
16
16
|
|
17
17
|
# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
|
18
18
|
# `weight` and `size_average` are placeholders and not implemented yet
|
@@ -156,6 +156,20 @@ def liger_kl_div(
|
|
156
156
|
eps,
|
157
157
|
)
|
158
158
|
|
159
|
+
def liger_tvd(
|
160
|
+
input,
|
161
|
+
target,
|
162
|
+
shift_labels=None,
|
163
|
+
reduction: str = "mean",
|
164
|
+
ignore_index: int = -100,
|
165
|
+
):
|
166
|
+
return LigerTVDLossFunction.apply(
|
167
|
+
input,
|
168
|
+
target,
|
169
|
+
shift_labels,
|
170
|
+
reduction,
|
171
|
+
ignore_index,
|
172
|
+
)
|
159
173
|
|
160
174
|
def liger_layer_norm(X, W, B, eps):
|
161
175
|
return LigerLayerNormFunction.apply(X, W, B, eps)
|
@@ -0,0 +1,15 @@
|
|
1
|
+
import torch.nn as nn
|
2
|
+
|
3
|
+
from liger_kernel.ops.tvd import LigerTVDLossFunction
|
4
|
+
|
5
|
+
|
6
|
+
class LigerTVDLoss(nn.Module):
|
7
|
+
def __init__(self, reduction="batchmean", ignore_index: int = -100):
|
8
|
+
super(LigerTVDLoss, self).__init__()
|
9
|
+
self.reduction = reduction
|
10
|
+
self.ignore_index = ignore_index
|
11
|
+
|
12
|
+
def forward(self, p, q, shift_labels=None):
|
13
|
+
return LigerTVDLossFunction.apply(
|
14
|
+
p, q, shift_labels, self.reduction, self.ignore_index
|
15
|
+
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: liger_kernel_nightly
|
3
|
-
Version: 0.5.3.
|
3
|
+
Version: 0.5.3.dev20250221003838
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
@@ -126,7 +126,7 @@ Requires-Dist: mkdocs-material; extra == "dev"
|
|
126
126
|
|
127
127
|
**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
|
128
128
|
|
129
|
-
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
129
|
+
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
130
130
|
|
131
131
|
## Supercharge Your Model with Liger Kernel
|
132
132
|
|
@@ -341,6 +341,7 @@ loss.backward()
|
|
341
341
|
| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
|
342
342
|
| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
|
343
343
|
| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
|
344
|
+
| Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
|
344
345
|
|
345
346
|
### Distillation Kernels
|
346
347
|
|
@@ -349,6 +350,7 @@ loss.backward()
|
|
349
350
|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
|
350
351
|
| JSD | `liger_kernel.transformers.LigerJSD` |
|
351
352
|
| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
|
353
|
+
| TVD | `liger_kernel.transformers.LigerTVDLoss` |
|
352
354
|
|
353
355
|
### Experimental Kernels
|
354
356
|
|
@@ -39,6 +39,7 @@ benchmark/scripts/benchmark_rms_norm.py
|
|
39
39
|
benchmark/scripts/benchmark_rope.py
|
40
40
|
benchmark/scripts/benchmark_simpo_loss.py
|
41
41
|
benchmark/scripts/benchmark_swiglu.py
|
42
|
+
benchmark/scripts/benchmark_tvd.py
|
42
43
|
benchmark/scripts/utils.py
|
43
44
|
dev/fmt-requirements.txt
|
44
45
|
dev/modal/tests.py
|
@@ -131,6 +132,7 @@ src/liger_kernel/ops/qwen2vl_mrope.py
|
|
131
132
|
src/liger_kernel/ops/rms_norm.py
|
132
133
|
src/liger_kernel/ops/rope.py
|
133
134
|
src/liger_kernel/ops/swiglu.py
|
135
|
+
src/liger_kernel/ops/tvd.py
|
134
136
|
src/liger_kernel/ops/utils.py
|
135
137
|
src/liger_kernel/ops/experimental/embedding.py
|
136
138
|
src/liger_kernel/ops/experimental/mm_int8int2.py
|
@@ -151,6 +153,7 @@ src/liger_kernel/transformers/rms_norm.py
|
|
151
153
|
src/liger_kernel/transformers/rope.py
|
152
154
|
src/liger_kernel/transformers/swiglu.py
|
153
155
|
src/liger_kernel/transformers/trainer_integration.py
|
156
|
+
src/liger_kernel/transformers/tvd.py
|
154
157
|
src/liger_kernel/transformers/experimental/embedding.py
|
155
158
|
src/liger_kernel/transformers/model/__init__.py
|
156
159
|
src/liger_kernel/transformers/model/gemma.py
|
@@ -216,4 +219,5 @@ test/transformers/test_rope.py
|
|
216
219
|
test/transformers/test_swiglu.py
|
217
220
|
test/transformers/test_trainer_integration.py
|
218
221
|
test/transformers/test_transformers.py
|
222
|
+
test/transformers/test_tvd.py
|
219
223
|
test/triton/test_triton_monkey_patch.py
|