liger-kernel-nightly 0.5.2.dev20250130024630__tar.gz → 0.5.2.dev20250130213846__tar.gz
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/data/all_benchmark_data.csv +24 -0
- liger_kernel_nightly-0.5.2.dev20250130213846/benchmark/scripts/benchmark_distill_jsd_loss.py +261 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/__init__.py +1 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/dpo_loss.py +5 -2
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/functional.py +2 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +14 -5
- liger_kernel_nightly-0.5.2.dev20250130213846/src/liger_kernel/chunked_loss/jsd_loss.py +154 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel_nightly.egg-info/SOURCES.txt +3 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/chunked_loss/test_dpo_loss.py +4 -1
- liger_kernel_nightly-0.5.2.dev20250130213846/test/chunked_loss/test_jsd_loss.py +318 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/utils.py +4 -1
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/Makefile +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/setup.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/convergence/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/convergence/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/convergence/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/triton/test_triton_monkey_patch.py +0 -0
{liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/benchmark/data/all_benchmark_data.csv

@@ -745,3 +745,27 @@ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.253906
 kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
 kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
 kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
+distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
+distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
+distill_jsd_loss,liger,forward,speed,ms,BT,B x T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
+distill_jsd_loss,liger,forward,speed,ms,BT,B x T,8192,60.24163055419922,60.24163055419922,60.24163055419922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
+distill_jsd_loss,torch,forward,speed,ms,BT,B x T,1024,10.906111717224121,10.903244972229004,10.91296672821045,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
+distill_jsd_loss,torch,forward,speed,ms,BT,B x T,2048,21.480207443237305,21.465139389038086,21.489286422729492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
+distill_jsd_loss,torch,forward,speed,ms,BT,B x T,4096,42.96339416503906,42.96237564086914,42.96440887451172,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
+distill_jsd_loss,torch,forward,speed,ms,BT,B x T,8192,85.3946533203125,85.3946533203125,85.3946533203125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
+distill_jsd_loss,liger,full,speed,ms,BT,B x T,1024,8.312895774841309,8.310400009155273,8.326751708984375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
+distill_jsd_loss,liger,full,speed,ms,BT,B x T,2048,15.770208358764648,15.767775535583496,15.774784088134766,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
+distill_jsd_loss,liger,full,speed,ms,BT,B x T,4096,30.922752380371094,30.920312881469727,30.927898406982422,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
+distill_jsd_loss,liger,full,speed,ms,BT,B x T,8192,60.70627212524414,60.70627212524414,60.70627212524414,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
+distill_jsd_loss,torch,full,speed,ms,BT,B x T,1024,28.72480010986328,28.718809127807617,28.728179931640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
+distill_jsd_loss,torch,full,speed,ms,BT,B x T,2048,54.281761169433594,54.281761169433594,54.281761169433594,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
+distill_jsd_loss,torch,full,speed,ms,BT,B x T,4096,107.08905792236328,107.08905792236328,107.08905792236328,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
+distill_jsd_loss,torch,full,speed,ms,BT,B x T,8192,213.1598663330078,213.1598663330078,213.1598663330078,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
+distill_jsd_loss,liger,full,memory,MB,BT,B x T,1024,10913.541015625,10913.541015625,10913.541015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
+distill_jsd_loss,liger,full,memory,MB,BT,B x T,2048,10941.548828125,10941.548828125,10941.548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
+distill_jsd_loss,liger,full,memory,MB,BT,B x T,4096,10997.564453125,10997.564453125,10997.564453125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
+distill_jsd_loss,liger,full,memory,MB,BT,B x T,8192,11109.595703125,11109.595703125,11109.595703125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
+distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,16174.0390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
+distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
+distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
+distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
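The 24 new rows record speed and peak memory for the chunked JSD distillation loss (H=4096, V=128256, bfloat16) on an H100, for both the Liger implementation and a plain PyTorch reference. A rough reading of the BT=8192 rows above in full (forward + backward) mode, as a quick sanity check of the headline numbers:

    # Values copied from the distill_jsd_loss rows above (BT=8192, full mode).
    liger_ms, torch_ms = 60.70627212524414, 213.1598663330078  # speed, ms
    liger_mb, torch_mb = 11109.595703125, 68947.1015625        # peak memory, MB

    print(f"~{torch_ms / liger_ms:.1f}x faster")            # ~3.5x
    print(f"~{torch_mb / liger_mb:.1f}x less peak memory")  # ~6.2x
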
liger_kernel_nightly-0.5.2.dev20250130213846/benchmark/scripts/benchmark_distill_jsd_loss.py

@@ -0,0 +1,261 @@
+import os
+import sys
+
+import torch
+import triton
+
+from utils import QUANTILES
+from utils import SingleBenchmarkRunInput
+from utils import SingleBenchmarkRunOutput
+from utils import _test_memory
+from utils import parse_benchmark_script_args
+from utils import run_benchmarks
+
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
+from liger_kernel.utils import infer_device
+
+device = infer_device()
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+
+
+class TorchJSDLoss(torch.nn.Module):
+    def __init__(
+        self,
+        H: int,
+        V: int,
+        dtype: torch.dtype,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        bias: bool = False,
+    ):
+        from test.chunked_loss.test_jsd_loss import HFJSDLoss
+
+        super().__init__()
+        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype)
+        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
+        self.jsd_loss = HFJSDLoss(
+            ignore_index=ignore_index,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            temperature=temperature,
+        ).get_batch_loss_metrics
+
+    def forward(self, student, teacher, target):
+        return self.jsd_loss(
+            student,
+            self.student_lin.weight,
+            teacher,
+            self.teacher_lin.weight,
+            target,
+        )
+
+
+class LigerJSDLoss(torch.nn.Module):
+    def __init__(
+        self,
+        H: int,
+        V: int,
+        dtype: torch.dtype,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        bias: bool = False,
+    ):
+        super().__init__()
+        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype)
+        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.jsd_loss = LigerFusedLinearJSDFunction.apply
+
+    def forward(self, student, teacher, target):
+        return self.jsd_loss(
+            student,
+            self.student_lin.weight,
+            teacher,
+            self.teacher_lin.weight,
+            target,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+        )
+
+
+def bench_memory_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    BT = input.x
+    H = input.extra_benchmark_config["H"]
+    V = input.extra_benchmark_config["V"]
+    dtype = input.extra_benchmark_config["dtype"]
+    bias = input.extra_benchmark_config["bias"]
+    weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
+    weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
+    ignore_index = input.extra_benchmark_config["ignore_index"]
+    provider = input.kernel_provider
+
+    torch_jsd_loss = TorchJSDLoss(
+        H=H,
+        V=V,
+        dtype=dtype,
+        ignore_index=ignore_index,
+        bias=bias,
+        weight_hard_loss=weight_hard_loss,
+        weight_soft_loss=weight_soft_loss,
+    ).to(device)
+    liger_jsd_loss = LigerJSDLoss(
+        H=H,
+        V=V,
+        dtype=dtype,
+        ignore_index=ignore_index,
+        bias=bias,
+        weight_hard_loss=weight_hard_loss,
+        weight_soft_loss=weight_soft_loss,
+    ).to(device)
+
+    _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
+    student_input1 = _tensor.detach().clone().requires_grad_(True)
+    student_input2 = _tensor.detach().clone().requires_grad_(True)
+
+    teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
+
+    target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
+
+    def fwd():
+        if provider == "liger":
+            return liger_jsd_loss(student_input1, teacher_input, target)
+        elif provider == "torch":
+            return torch_jsd_loss(student_input2, teacher_input, target)
+
+    def full():
+        y = fwd()
+        y.backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(
+        y_20=mem_20,
+        y_50=mem_50,
+        y_80=mem_80,
+    )
+
+
+def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    BT = input.x
+    H = input.extra_benchmark_config["H"]
+    V = input.extra_benchmark_config["V"]
+    dtype = input.extra_benchmark_config["dtype"]
+    bias = input.extra_benchmark_config["bias"]
+    weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
+    weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
+    ignore_index = input.extra_benchmark_config["ignore_index"]
+    provider = input.kernel_provider
+    mode = input.kernel_operation_mode
+
+    torch_jsd_loss = TorchJSDLoss(
+        H=H,
+        V=V,
+        dtype=dtype,
+        ignore_index=ignore_index,
+        bias=bias,
+        weight_hard_loss=weight_hard_loss,
+        weight_soft_loss=weight_soft_loss,
+    ).to(device)
+    liger_jsd_loss = LigerJSDLoss(
+        H=H,
+        V=V,
+        dtype=dtype,
+        ignore_index=ignore_index,
+        bias=bias,
+        weight_hard_loss=weight_hard_loss,
+        weight_soft_loss=weight_soft_loss,
+    ).to(device)
+
+    _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
+    student_input1 = _tensor.detach().clone().requires_grad_(True)
+    student_input2 = _tensor.detach().clone().requires_grad_(True)
+
+    teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
+
+    target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
+
+    def fwd():
+        if provider == "liger":
+            return liger_jsd_loss(student_input1, teacher_input, target)
+        elif provider == "torch":
+            return torch_jsd_loss(student_input2, teacher_input, target)
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            fwd,
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "backward":
+        y = fwd()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(retain_graph=True),
+            grad_to_none=[student_input1, student_input2],
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
+
+        def full():
+            y = fwd()
+            y.backward()
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            full,
+            rep=100,
+            quantiles=QUANTILES,
+        )
+
+    return SingleBenchmarkRunOutput(
+        y_20=ms_20,
+        y_50=ms_50,
+        y_80=ms_80,
+    )
+
+
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()
+
+    common_configs = {
+        "kernel_name": "distill_jsd_loss",
+        "x_name": "BT",
+        "x_label": "B x T",
+        "x_values": [2**i for i in range(10, 14)],
+        "kernel_providers": ["liger", "torch"],
+        "extra_benchmark_configs": [
+            {
+                "H": 4096,
+                "V": 128256,
+                "mode": "forward",
+                "dtype": torch.bfloat16,
+                "bias": False,
+                "weight_hard_loss": 0.5,
+                "weight_soft_loss": 0.5,
+                "ignore_index": -100,
+            }
+        ],
+        "overwrite": args.overwrite,
+    }
+
+    run_benchmarks(
+        bench_test_fn=bench_speed_jsd_loss,
+        kernel_operation_modes=["forward", "full"],
+        metric_name="speed",
+        metric_unit="ms",
+        **common_configs,
+    )
+
+    run_benchmarks(
+        bench_test_fn=bench_memory_jsd_loss,
+        kernel_operation_modes=["full"],
+        metric_name="memory",
+        metric_unit="MB",
+        **common_configs,
+    )
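The sweep in common_configs expands to the BT values recorded in the CSV rows above, and the two run_benchmarks calls (speed over forward/full plus memory over full, for two providers) account for all 24 new rows; running the script from benchmark/scripts regenerates rows of this shape, with overwriting controlled by the option parsed into args.overwrite:

    >>> [2**i for i in range(10, 14)]  # "x_values" -> the BT column
    [1024, 2048, 4096, 8192]
    >>> 2 * 4 * (2 + 1)                # providers x sizes x (speed modes + memory modes)
    24
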
{liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.5.2.dev20250130024630"
+version = "0.5.2.dev20250130213846"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
{liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/__init__.py

@@ -1,5 +1,6 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss  # noqa: F401
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss  # noqa: F401
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss  # noqa: F401
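With this line the new distillation loss is re-exported at the package level alongside the other chunked losses, so downstream code can import it directly:

    from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss
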
{liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/dpo_loss.py

@@ -45,9 +45,12 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         chosen_logratios = chosen_logps - ref_chosen_logps
         rejected_logratios = rejected_logps - ref_rejected_logps
 
+        chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
+        rejected_rewards = beta * (rejected_logps - ref_rejected_logps)
+
         logits_diff = beta * (chosen_logratios - rejected_logratios)
         loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
-        return loss
+        return loss, chosen_rewards, rejected_rewards
 
     @staticmethod
     def forward(
@@ -99,7 +102,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         beta: float = 0.1,
         compute_nll_loss: bool = False,
         compiled: bool = True,
-        use_ref_model: bool = False,
+        use_ref_model: bool = True,
     ):
         """
         Args:
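The reward terms added above are simply the beta-scaled gaps between policy and reference log-probabilities, now returned next to the loss; in isolation the computation is just (hypothetical log-prob values, beta at its default of 0.1):

    import torch

    beta = 0.1
    chosen_logps = torch.tensor([-2.0, -1.5])      # hypothetical policy log-probs
    ref_chosen_logps = torch.tensor([-2.5, -1.0])  # hypothetical reference log-probs

    chosen_rewards = beta * (chosen_logps - ref_chosen_logps)  # tensor([ 0.0500, -0.0500])
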
{liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/functional.py

@@ -1,11 +1,13 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 
 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
+liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
{liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel/chunked_loss/fused_linear_distillation.py

@@ -17,6 +17,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         Args:
             student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
             teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+        Returns:
+            torch.Tensor: Sum of distillation losses for the chunk. The class will handle
+                converting this to mean loss by dividing by the full batch size * sequence length in _compute_loss.
         """
         raise NotImplementedError("Distillation loss function must be implemented.")
 
@@ -71,10 +74,11 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
         compute_ce_loss=True,
+        temperature=1,
         **loss_kwargs,
     ):
         """
-        Compute the total loss for a chunk of input and target, while using an
+        Compute the total loss for a chunk of input and target, while using an knowledge distillation loss function.
         Args:
             distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
@@ -84,11 +88,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
             student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            full_target (torch.Tensor): Full target tensor. Shape: (
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard loss.
             weight_soft_loss (float): Weight for soft loss.
             compute_ce_loss (bool): Whether to compute CE loss.
+            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             loss_kwargs (dict): Additional arguments for the loss function.
         """
         (
@@ -107,6 +112,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss=compute_ce_loss,
         )
 
+        student_logits_chunk /= temperature
+        teacher_logits_chunk /= temperature
+
         hard_loss /= full_target.shape[0]
 
         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
@@ -130,6 +138,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         ignore_index=-100,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
+        beta=0.5,
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
@@ -152,6 +161,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard/task loss.
             weight_soft_loss (float): Weight for soft/distillation loss.
+            beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
@@ -170,7 +180,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
+            beta=beta,
             compute_ce_loss=compute_ce_loss,
+            temperature=temperature,
             **loss_kwargs,
         )
 
@@ -225,9 +237,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         if compiled:
             accumulate_chunk = torch.compile(accumulate_chunk)
 
-        student_input /= temperature
-        teacher_input /= temperature
-
         num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
         _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
         _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
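Net effect of these changes: the temperature now scales each chunk's projected logits (rather than the raw hidden states), a beta knob and the temperature are forwarded into the per-chunk loss call, and the new docstrings state that per-chunk losses are sums normalized by the full batch size * sequence length. As a sketch, with N = batch_size * sequence_length, T the temperature, z_s / z_t a chunk's student / teacher logits and y the hard labels:

    \mathcal{L} \;=\; w_{\mathrm{hard}} \cdot \frac{1}{N} \sum_{\text{chunks}} \mathrm{CE}(z_s, y)
      \;+\; w_{\mathrm{soft}} \cdot \frac{1}{N} \sum_{\text{chunks}} \ell_{\mathrm{distill}}\!\left(\frac{z_s}{T}, \frac{z_t}{T}\right)
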
liger_kernel_nightly-0.5.2.dev20250130213846/src/liger_kernel/chunked_loss/jsd_loss.py

@@ -0,0 +1,154 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+        """
+        Compute JSD loss (Jensen-Shannon Divergence Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        Returns:
+            torch.Tensor: Jensen-Shannon Divergence loss
+        """
+        student_log_probs = F.log_softmax(student_logits, dim=-1)
+        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+        # Compute probabilities (only required for mean calculation)
+        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
+        log_mean_probs = mean_probs.log()
+
+        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+
+        # JSD is the weighted average of the KL divergences
+        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+        return jsd_loss
+
+    @staticmethod
+    def forward(
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Fused linear layer with JSD distillation loss.
+        Args:
+            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
+            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
+            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
+            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
+            true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            ignore_index (int): Index to ignore in loss computation
+            temperature (float): Temperature for softening/sharpening distributions
+            compiled (bool): Whether to use torch compile
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearDistillationBase.forward(
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
+            chunk_size=1,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
+
+        return (*grads, None, None, None, None, None, None, None)
+
+
+class LigerFusedLinearJSDLoss(torch.nn.Module):
+    """
+    Fused linear layer with JSD distillation loss.
+    """
+
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Args:
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            ignore_index (int): Index to ignore in the loss
+            temperature (float): Temperature for softening distributions
+            compiled (bool): Whether to use torch compile
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        """
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+    ) -> torch.Tensor:
+        """
+        Compute the JSD distillation loss.
+
+        Args:
+            student_input (torch.Tensor): Student input tensor
+            student_weight (torch.Tensor): Student weight tensor
+            teacher_input (torch.Tensor): Teacher input tensor
+            teacher_weight (torch.Tensor): Teacher weight tensor
+            true_labels (torch.LongTensor): Target labels tensor
+
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearJSDFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+        )
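distillation_loss_fn above is the generalized Jensen-Shannon divergence, summed over the chunk and normalized afterwards by the base class. With P_s and P_t the student and teacher softmax distributions:

    \mathrm{JSD}_{\beta}(P_s, P_t) \;=\; \beta\,\mathrm{KL}(P_t \,\|\, M) \;+\; (1-\beta)\,\mathrm{KL}(P_s \,\|\, M),
    \qquad M = \beta P_s + (1-\beta) P_t

At beta = 0.5 this reduces to the standard symmetric JSD. A minimal usage sketch of the module wrapper (toy, hypothetical shapes; the inputs are flattened last hidden states plus the student and teacher lm_head weight matrices):

    import torch

    from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss

    B_T, H_student, H_teacher, V = 8, 32, 64, 128  # hypothetical toy sizes
    loss_fn = LigerFusedLinearJSDLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5, temperature=2.0)

    student_hidden = torch.randn(B_T, H_student, requires_grad=True)  # student hidden states, flattened to (B*T, H)
    student_head = torch.randn(V, H_student, requires_grad=True)      # student lm_head weight
    teacher_hidden = torch.randn(B_T, H_teacher)                      # teacher hidden states
    teacher_head = torch.randn(V, H_teacher)                          # teacher lm_head weight
    labels = torch.randint(0, V, (B_T,))

    loss = loss_fn(student_hidden, student_head, teacher_hidden, teacher_head, labels)
    loss.backward()  # only the student tensors require grad here, so only they receive gradients
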
{liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/src/liger_kernel_nightly.egg-info/SOURCES.txt

@@ -21,6 +21,7 @@ benchmark/data/all_benchmark_data.csv
 benchmark/scripts/__init__.py
 benchmark/scripts/benchmark_cpo_loss.py
 benchmark/scripts/benchmark_cross_entropy.py
+benchmark/scripts/benchmark_distill_jsd_loss.py
 benchmark/scripts/benchmark_dpo_loss.py
 benchmark/scripts/benchmark_embedding.py
 benchmark/scripts/benchmark_fused_linear_cross_entropy.py
@@ -110,6 +111,7 @@ src/liger_kernel/chunked_loss/functional.py
 src/liger_kernel/chunked_loss/fused_linear_distillation.py
 src/liger_kernel/chunked_loss/fused_linear_preference.py
 src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py
+src/liger_kernel/chunked_loss/jsd_loss.py
 src/liger_kernel/chunked_loss/kto_loss.py
 src/liger_kernel/chunked_loss/orpo_loss.py
 src/liger_kernel/chunked_loss/simpo_loss.py
@@ -172,6 +174,7 @@ test/utils.py
 test/chunked_loss/__init__.py
 test/chunked_loss/test_cpo_loss.py
 test/chunked_loss/test_dpo_loss.py
+test/chunked_loss/test_jsd_loss.py
 test/chunked_loss/test_kto_loss.py
 test/chunked_loss/test_orpo_loss.py
 test/chunked_loss/test_simpo_loss.py
{liger_kernel_nightly-0.5.2.dev20250130024630 → liger_kernel_nightly-0.5.2.dev20250130213846}/test/chunked_loss/test_dpo_loss.py

@@ -56,9 +56,12 @@ class HFDPOLoss(HFAlignmentLoss):
         chosen_logratios = policy_chosen_logps - ref_chosen_logps
         rejected_logratios = policy_rejected_logps - ref_rejected_logps
 
+        chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps)
+        rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps)
+
         logits_diff = self.beta * (chosen_logratios - rejected_logratios)
         losses = -F.logsigmoid(logits_diff)
-        return losses
+        return losses, chosen_rewards, rejected_rewards
 
 
 class TorchLMHeadDPO(torch.nn.Module):
|