liger-kernel-nightly 0.5.2.dev20241229035411__tar.gz → 0.5.2.dev20241229131950__tar.gz
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/cross_entropy.py +88 -19
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/fused_linear_cross_entropy.py +50 -29
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/cross_entropy.py +3 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/functional.py +3 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +3 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_cross_entropy.py +184 -3
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_fused_linear_cross_entropy.py +15 -8
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/Makefile +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/docs/Acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/docs/CONTRIBUTING.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/docs/License.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/setup.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/convergence/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/convergence/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/convergence/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229035411 → liger_kernel_nightly-0.5.2.dev20241229131950}/test/utils.py +0 -0
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.5.2.dev20241229035411"
+version = "0.5.2.dev20241229131950"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
src/liger_kernel/ops/cross_entropy.py

@@ -30,11 +30,14 @@ def liger_cross_entropy_kernel(
     X_stride,
     Y_ptr,
     Y_stride,
+    weight_ptr,
     loss_ptr,
     z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
+    sum_non_ignore_weight,
+    weight_sum,
     ignore_index,
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
@@ -42,6 +45,7 @@ def liger_cross_entropy_kernel(
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
@@ -53,18 +57,22 @@ def liger_cross_entropy_kernel(
     X_stride (int): The stride of the input tensor.
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
+    weight_ptr: Pointer to weight tensor.
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (int): The number of non-ignored elements in the batch.
+    n_non_ignore (flaot): The number of non-ignored elements in the batch.
+    sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
+    weight_sum (float): The sum of weight tensor.
    ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
-    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """
 
@@ -89,6 +97,9 @@ def liger_cross_entropy_kernel(
     loss_ptr += program_id * loss_stride
     z_loss_ptr += program_id * loss_stride
 
+    if HAS_WEIGHT:
+        weight_y = tl.load(weight_ptr + y).cast(tl.float32)
+
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
 
@@ -117,7 +128,11 @@ def liger_cross_entropy_kernel(
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
-            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
+            if HAS_WEIGHT:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
+            else:
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
@@ -153,18 +168,41 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
-        # softmax(x_i)
-        X_block = tl.exp(X_block - m) / d
-        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-        X_block += 2 * lse_square_scale * lse * X_block
-        # smoothing term
-        X_block += -eps
-        # special handle dx_y
-        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-        # reduction scale
-        if reduction == "mean":
-            X_block = X_block / n_non_ignore
-        # chain rule
+
+        if not HAS_WEIGHT:
+            # softmax(x_i)
+            X_block = tl.exp(X_block - m) / d
+            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+            X_block += 2 * lse_square_scale * lse * X_block
+            # smoothing term
+            X_block += -eps
+            # special handle dx_y
+            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+            # reduction scale
+            if reduction == "mean":
+                X_block = X_block / n_non_ignore
+        else:
+            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+            softmax_X = tl.exp(X_block - m) / d
+            # derivative of original_loss
+            dloss_ori = (1 - label_smoothing) * softmax_X
+            # specially handle dx_y
+            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+            dloss_ori = dloss_ori * weight_y
+            # derivative of smooth_loss
+            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+            # derivative of z-loss
+            dz_loss = 2 * lse_square_scale * lse * softmax_X
+            # reduction scale
+            if reduction == "mean":
+                dloss_ori = dloss_ori / sum_non_ignore_weight
+                dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                dz_loss = dz_loss / n_non_ignore
+            # derivative of total_loss
+            X_block = dloss_ori + dloss_smooth + dz_loss
+
+        # chain rule softcapping
         # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
         if HAS_SOFTCAPPING:
             X_block = X_block * (1 - intermediate * intermediate)
@@ -183,6 +221,8 @@ def liger_cross_entropy_kernel(
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
     loss = lse - ori_X_y
+    if HAS_WEIGHT:
+        loss = weight_y * loss
 
     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -193,17 +233,24 @@ def liger_cross_entropy_kernel(
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing * lse
+        if HAS_WEIGHT:
+            smooth_loss = scaled_x_sum + eps * lse * weight_sum
+        else:
+            smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss
 
     # An auxiliary loss, z_loss
     # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
     z_loss = lse_square_scale * lse * lse
-    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        if HAS_WEIGHT:
+            loss = loss / sum_non_ignore_weight
+        else:
+            loss = loss / n_non_ignore
+        # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
         z_loss = z_loss / n_non_ignore
-        loss = loss / n_non_ignore
+    loss += z_loss
 
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS == _TRUE:
@@ -225,6 +272,7 @@ _bool_to_return_z_loss = {
 def cross_entropy_forward(
     _input,
     target,
+    weight,
     ignore_index,
     lse_square_scale,
     label_smoothing,
@@ -250,7 +298,20 @@ def cross_entropy_forward(
     else:
         z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False
 
-    n_non_ignore = (target != ignore_index).sum().item()
+    target_mask = target != ignore_index
+    n_non_ignore = target_mask.sum().item()
+    sum_non_ignore_weight = n_non_ignore
+    weight_sum = 0.0
+    if weight is not None:
+        assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+        assert torch.is_floating_point(
+            weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        weight_sum = weight.sum().item()
+        # ensure weight is contiguous
+        if weight.stride(-1) != 1:
+            weight = weight.contiguous()
 
     # ensure _input and target are contiguous in the last dimension
     if _input.stride(-1) != 1:
@@ -264,18 +325,22 @@ def cross_entropy_forward(
         X_stride=_input.stride(-2),
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
+        weight_ptr=weight if weight is not None else _input,  # dummy if None
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
+        sum_non_ignore_weight=sum_non_ignore_weight,
         ignore_index=ignore_index,
+        weight_sum=weight_sum,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
         softcap=softcap if softcap is not None else 0.0,
         RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
@@ -327,6 +392,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx,
         _input: torch.Tensor,
         target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -341,6 +407,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx : The context object.
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+        weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index (int): The index to ignore in the target.
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -354,6 +421,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         loss, z_loss, _input = cross_entropy_forward(
             _input,
             target,
+            weight,
             ignore_index,
             lse_square_scale,
             label_smoothing,
@@ -395,4 +463,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
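The weighted branch above follows the same convention as PyTorch's class-weighted cross entropy: each token's loss `lse - x_y` is scaled by `weight[y]`, and under `reduction="mean"` the total is divided by `sum_non_ignore_weight` (the summed weights of the non-ignored targets) rather than by the plain token count. A minimal reference sketch of that bookkeeping in plain PyTorch, ignoring label smoothing and z-loss and using made-up shapes (not Liger code):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: 8 tokens over a 16-class vocabulary.
logits = torch.randn(8, 16)
target = torch.randint(0, 16, (8,))
target[0] = -100                       # one ignored position
weight = torch.rand(16)                # per-class rescaling weights

# Per-token weighted loss: -weight[y] * log_softmax(x)[y] over non-ignored tokens.
mask = target != -100
log_probs = F.log_softmax(logits.float(), dim=-1)
per_token = -weight[target[mask]] * log_probs[mask, target[mask]]

# "mean" reduction divides by the summed weights of the non-ignored targets
# (the sum_non_ignore_weight quantity in the kernel), not by the token count.
manual_mean = per_token.sum() / weight[target[mask]].sum()

ref = F.cross_entropy(logits.float(), target, weight=weight, ignore_index=-100)
assert torch.allclose(manual_mean, ref, atol=1e-5)
```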
src/liger_kernel/ops/fused_linear_cross_entropy.py

@@ -17,6 +17,7 @@ def fused_linear_cross_entropy_forward(
     _input,
     weight,
     target,
+    ce_weight=None,
     bias=None,
     ignore_index=-100,
     lse_square_scale=0.0,
@@ -47,8 +48,22 @@ def fused_linear_cross_entropy_forward(
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
 
-    # NOTE: skip .item() here to avoid CUDA synchronization
-    total_n_non_ignore = (target != ignore_index).sum()
+    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
+    target_mask = target != ignore_index
+    total_n_non_ignore = target_mask.sum().item()
+    total_sum_non_ignore_ce_weight = total_n_non_ignore
+    ce_weight_sum = 0.0
+    if ce_weight is not None:
+        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+        assert torch.is_floating_point(
+            ce_weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        total_sum_non_ignore_ce_weight = (
+            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        )
+        ce_weight_sum = ce_weight.sum().item()
+        if ce_weight.stride(-1) != 1:
+            ce_weight = ce_weight.contiguous()
 
     for chunk_id in range(num_chunks):
         start_idx = chunk_id * chunk_size
@@ -66,7 +81,6 @@ def fused_linear_cross_entropy_forward(
 
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-        n_non_ignore = (target_chunk != ignore_index).sum().item()
 
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -78,35 +92,28 @@ def fused_linear_cross_entropy_forward(
             X_stride=logits_chunk.stride(-2),
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
+            weight_ptr=ce_weight if ce_weight is not None else _input,  # dummy if None
             loss_ptr=loss_1d_slice,
             z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
             loss_stride=loss_1d_slice.stride(-1),  # always 1
             n_cols=V,
-            n_non_ignore=n_non_ignore,
+            n_non_ignore=total_n_non_ignore,
+            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+            weight_sum=ce_weight_sum,
             ignore_index=ignore_index,
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
             softcap=softcap if softcap is not None else 0.0,
             RETURN_Z_LOSS=0,  # False
+            HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
 
-
-
-        # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
-        # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
-        # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-
-        if reduction == "mean":
-            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
-        else:
-            alpha = 1.0
-
-        loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
-        grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        grad_logits_chunk = logits_chunk  # chunk_size x V
 
         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
@@ -118,7 +125,7 @@ def fused_linear_cross_entropy_forward(
                 ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
                 mat2=_input_chunk,
                 out=grad_weight,
-                alpha=alpha,
+                alpha=1.0,
                 beta=1.0,
             )
 
@@ -127,7 +134,7 @@ def fused_linear_cross_entropy_forward(
                 input=grad_bias,
                 other=logits_chunk.sum(dim=0),
                 out=grad_bias,
-                alpha=alpha,
+                alpha=1.0,
             )
 
     if reduction == "none":
@@ -193,6 +200,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         weight,
         target,
         bias=None,
+        ce_weight=None,
         ignore_index=-100,
         lse_square_scale=0.0,
         label_smoothing=0.0,
@@ -212,21 +220,23 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         target: (B*T) where each value is in [0, V-1]
         weight: (V, H) where V is the number of classes
         bias: (V) where V is the number of classes
+        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index: the index to ignore in the target
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction: reduction to apply
         """
 
         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input,
-            weight,
-            target,
-            bias,
-            ignore_index,
-            lse_square_scale,
-            label_smoothing,
-            reduction,
-            softcap,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ce_weight=ce_weight,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -243,4 +253,15 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)
+        return (
+            grad_input,
+            grad_weight,
+            None,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
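Because the kernel is now handed the global totals (`total_n_non_ignore`, `total_sum_non_ignore_ce_weight`) instead of per-chunk counts, the old per-chunk `alpha = n_non_ignore / total_n_non_ignore` rescaling of the loss and gradients disappears: every chunk is already normalized by the global denominator inside the kernel, so the chunks can simply be summed. A toy illustration of why the two schemes agree (plain PyTorch arithmetic, not Liger code, ignoring ignore_index):

```python
import torch

# 10 per-token losses processed in chunks of 4 (illustrative numbers only).
per_token_loss = torch.rand(10)
n_total = per_token_loss.numel()
chunks = per_token_loss.split(4)

# Old scheme: each chunk was averaged over its own length, then rescaled by
# alpha = n_chunk / n_total before being accumulated.
old = sum(chunk.mean() * (chunk.numel() / n_total) for chunk in chunks)

# New scheme: every token is divided by the global count up front, so the
# accumulated chunks need no further rescaling.
new = sum((chunk / n_total).sum() for chunk in chunks)

assert torch.allclose(old, new) and torch.allclose(new, per_token_loss.mean())
```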
src/liger_kernel/transformers/cross_entropy.py

@@ -8,6 +8,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 class LigerCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -28,6 +29,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.weight = weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -39,6 +41,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
         loss, z_loss = LigerCrossEntropyFunction.apply(
             _input,
             target,
+            self.weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
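A minimal usage sketch of the new `weight` argument on the module (illustrative shapes, assuming a CUDA device since the loss is backed by a Triton kernel; not part of the diff):

```python
import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

B, T, V = 2, 8, 32  # made-up sizes; weight must be a float tensor of length V
class_weights = torch.rand(V, device="cuda")

loss_fn = LigerCrossEntropyLoss(weight=class_weights, ignore_index=-100)
logits = torch.randn(B * T, V, device="cuda", requires_grad=True)
labels = torch.randint(0, V, (B * T,), device="cuda")

loss = loss_fn(logits, labels)  # scalar under the default "mean" reduction
loss.backward()
```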
src/liger_kernel/transformers/functional.py

@@ -32,6 +32,7 @@ def liger_cross_entropy(
     loss, z_loss = LigerCrossEntropyFunction.apply(
         input,
         target,
+        weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
@@ -49,6 +50,7 @@ def liger_fused_linear_cross_entropy(
     weight,
     target,
     bias=None,
+    ce_weight=None,
     ignore_index: int = -100,
     lse_square_scale: float = 0.0,
     label_smoothing: float = 0.0,
@@ -60,6 +62,7 @@ def liger_fused_linear_cross_entropy(
         weight,
         target,
         bias,
+        ce_weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
src/liger_kernel/transformers/fused_linear_cross_entropy.py

@@ -8,6 +8,7 @@ from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEnt
 class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        ce_weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -24,6 +25,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.ce_weight = ce_weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -36,6 +38,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             lin_weight,
             target,
             bias,
+            self.ce_weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
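And a corresponding sketch for the fused module, which takes the lm_head weight and the last hidden states instead of materialized logits. The call order below (lm_head weight first, then hidden states, then targets) follows Liger's patched model code and is an assumption of this example, as are all of the shapes:

```python
import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

B, T, H, V = 2, 8, 64, 128  # made-up sizes
ce_weight = torch.rand(V, device="cuda")  # per-class weights for the CE term only

loss_fn = LigerFusedLinearCrossEntropyLoss(ce_weight=ce_weight, reduction="mean")
hidden = torch.randn(B * T, H, device="cuda", requires_grad=True)
lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)
labels = torch.randint(0, V, (B * T,), device="cuda")

loss = loss_fn(lm_head_weight, hidden, labels)  # bias defaults to None
loss.backward()
```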