liger-kernel 0.5.3__tar.gz → 0.5.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel-0.5.4/.github/workflows/intel-ci.yml +71 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/Makefile +11 -5
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/PKG-INFO +17 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/README.md +16 -2
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/data/all_benchmark_data.csv +37 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_kto_loss.py +6 -6
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_rope.py +1 -1
- liger_kernel-0.5.4/benchmark/scripts/benchmark_tvd.py +133 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/dev/modal/tests.py +1 -1
- liger_kernel-0.5.4/docs/images/post-training.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/lightning/training.py +1 -1
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/callback.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/pyproject.toml +1 -1
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/setup.py +13 -4
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel-0.5.4/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
- liger_kernel-0.5.4/src/liger_kernel/chunked_loss/grpo_loss.py +160 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/kto_loss.py +9 -9
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/cross_entropy.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/fused_linear_cross_entropy.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/fused_linear_jsd.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/jsd.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/layer_norm.py +20 -7
- liger_kernel-0.5.4/src/liger_kernel/ops/tvd.py +207 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/utils.py +1 -2
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/__init__.py +3 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/cross_entropy.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/functional.py +17 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/group_norm.py +6 -6
- liger_kernel-0.5.4/src/liger_kernel/transformers/model/olmo2.py +124 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/monkey_patch.py +171 -27
- liger_kernel-0.5.4/src/liger_kernel/transformers/tvd.py +13 -0
- liger_kernel-0.5.4/src/liger_kernel/utils.py +62 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel.egg-info/PKG-INFO +17 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel.egg-info/SOURCES.txt +18 -3
- liger_kernel-0.5.4/test/chunked_loss/test_grpo_loss.py +275 -0
- liger_kernel-0.5.4/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.4/test/convergence/bf16}/test_mini_models.py +121 -37
- {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.4/test/convergence/bf16}/test_mini_models_multimodal.py +1 -36
- {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.4/test/convergence/bf16}/test_mini_models_with_logits.py +120 -36
- liger_kernel-0.5.4/test/convergence/fp32/__init__.py +0 -0
- liger_kernel-0.5.4/test/convergence/fp32/test_mini_models.py +664 -0
- liger_kernel-0.5.4/test/convergence/fp32/test_mini_models_multimodal.py +415 -0
- liger_kernel-0.5.4/test/convergence/fp32/test_mini_models_with_logits.py +663 -0
- liger_kernel-0.5.4/test/transformers/test_flex_attention.py +291 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_layer_norm.py +20 -5
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_monkey_patch.py +54 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_qwen2vl_mrope.py +3 -2
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_rope.py +1 -1
- liger_kernel-0.5.4/test/transformers/test_tvd.py +188 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/triton/test_triton_monkey_patch.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/utils.py +18 -41
- liger_kernel-0.5.3/docs/images/post-training.png +0 -0
- liger_kernel-0.5.3/src/liger_kernel/utils.py +0 -13
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/pull_request_template.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/workflows/docs.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.gitignore +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/LICENSE +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/NOTICE +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/README.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/utils.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/dev/fmt-requirements.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/Examples.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/Getting-Started.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/High-Level-APIs.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/acknowledgement.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/contributing.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/banner.GIF +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/compose.gif +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/e2e-memory.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/e2e-tps.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/logo-banner.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/patch.gif +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/index.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/license.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/README.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/callback.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/training.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/lightning/README.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/lightning/requirements.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/README.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/requirements.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/train.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/mkdocs.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/setup.cfg +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel.egg-info/requires.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel.egg-info/top_level.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/conftest.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/convergence/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_embedding.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_geglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_transformers.py +0 -0
liger_kernel-0.5.4/.github/workflows/intel-ci.yml

```diff
@@ -0,0 +1,71 @@
+name: Intel GPU
+
+on:
+  push:
+    branches:
+      - main
+    paths:
+      - "src/**"
+      - "test/**"
+  pull_request:
+    branches:
+      - main
+    paths:
+      - "src/**"
+      - "test/**"
+  schedule:
+    # Runs at 00:00 UTC daily
+    - cron: '0 0 * * *'
+  workflow_dispatch: # Enables manual trigger
+
+concurrency:
+  # This causes it to cancel previous in-progress actions on the same PR / branch,
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  checkstyle:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: '3.10'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r dev/fmt-requirements.txt
+
+      - name: Run checkstyle
+        run: make checkstyle
+
+  tests:
+    runs-on: linux-max1550-gpu-8
+    needs: [checkstyle]
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: '3.10'
+
+      - name: Setup Dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/test/xpu
+
+      - name: List Python Environments
+        run: python -m pip list
+
+      - name: Run Unit Tests
+        run: |
+          make test
+          make test-convergence
```
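The workflow above runs the same `make test` and `make test-convergence` targets on an Intel Max 1550 runner, installing PyTorch from the XPU wheel index. Device selection in the kernels, tests, and benchmarks goes through the new `liger_kernel.utils.infer_device()` helper (the new `src/liger_kernel/utils.py`, +62 lines, also imported by the benchmark scripts later in this diff). A rough sketch of what such a helper looks like, not the shipped implementation:

```python
# Illustrative sketch only -- NOT the actual liger_kernel/utils.py.
# It shows how one helper can route the same tests to CUDA/ROCm or Intel XPU.
import torch


def infer_device() -> str:
    if torch.cuda.is_available():  # NVIDIA and ROCm builds both report "cuda"
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel GPUs
        return "xpu"
    return "cpu"
```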
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/Makefile

```diff
@@ -9,8 +9,10 @@ test:
 
 # Command to run ruff for linting and formatting code
 checkstyle:
-	ruff check
-	ruff format .; ruff_format_status=$$?; \
+	ruff check .; ruff_check_status=$$?; \
+	ruff format --check .; ruff_format_status=$$?; \
+	ruff check . --fix; \
+	ruff format .; \
 	if [ $$ruff_check_status -ne 0 ] || [ $$ruff_format_status -ne 0 ]; then \
 		exit 1; \
 	fi
@@ -18,9 +20,13 @@ checkstyle:
 # Command to run pytest for convergence tests
 # We have to explicitly set HF_DATASETS_OFFLINE=1, or dataset will silently try to send metrics and timeout (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286
 test-convergence:
-	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
-	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_multimodal.py
-	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_with_logits.py
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_multimodal.py
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_with_logits.py
+
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models.py
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py
 
 # Command to run all benchmark scripts and update benchmarking data file
 # By default this doesn't overwrite existing data for the same benchmark experiment
```
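The reworked `checkstyle` target records the status of both the lint pass and the format check before applying any fixes, so CI still fails when the tree had violations even though the files get auto-fixed locally. Outside of make, the sequence is roughly the following sketch (shell syntax, so `$?` instead of the Makefile's escaped `$$?`):

```bash
# Roughly what the new `make checkstyle` does (sketch, not the Makefile itself).
ruff check .; ruff_check_status=$?             # lint, remember the status
ruff format --check .; ruff_format_status=$?   # verify formatting, remember the status
ruff check . --fix                             # then auto-fix lint issues in place
ruff format .                                  # and reformat in place
if [ $ruff_check_status -ne 0 ] || [ $ruff_format_status -ne 0 ]; then
    exit 1                                     # still fail if the original tree was dirty
fi
```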
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: liger_kernel
-Version: 0.5.3
+Version: 0.5.4
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
         Copyright 2024 LinkedIn Corporation
@@ -97,6 +97,11 @@ Dynamic: requires-dist
     <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
   </a>
   </div>
+  <div style="display: block;">
+  <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+    <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
+  </a>
+  </div>
   </td>
   </tr>
 </table>
@@ -123,7 +128,7 @@ Dynamic: requires-dist
 
 **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
 
-We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
+We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
 
 ## Supercharge Your Model with Liger Kernel
 
@@ -188,6 +193,11 @@ y = orpo_loss(lm_head.weight, x, target)
 - `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
 - `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
 
+```bash
+# Need to pass the url when installing
+pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
+```
+
 ### Optional Dependencies
 
 - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
@@ -305,6 +315,8 @@ loss.backward()
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
 ## Low-level APIs
@@ -333,6 +345,7 @@ loss.backward()
 | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
 | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
 | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+| Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
 
 ### Distillation Kernels
 
@@ -341,6 +354,7 @@ loss.backward()
 | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
 | JSD | `liger_kernel.transformers.LigerJSD` |
 | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
+| TVD | `liger_kernel.transformers.LigerTVDLoss` |
 
 ### Experimental Kernels
 
@@ -372,7 +386,7 @@ loss.backward()
 
 - For issues, create a Github ticket in this repository
 - For open discussion, join [our discord channel](https://discord.gg/gpumode)
-- For formal collaboration, send an email to
+- For formal collaboration, send an email to yannchen@linkedin.com
 
 ## Cite this work
 
```
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/README.md

```diff
@@ -47,6 +47,11 @@
     <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
   </a>
   </div>
+  <div style="display: block;">
+  <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+    <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
+  </a>
+  </div>
   </td>
   </tr>
 </table>
@@ -73,7 +78,7 @@
 
 **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
 
-We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
+We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
 
 ## Supercharge Your Model with Liger Kernel
 
@@ -138,6 +143,11 @@ y = orpo_loss(lm_head.weight, x, target)
 - `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
 - `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
 
+```bash
+# Need to pass the url when installing
+pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
+```
+
 ### Optional Dependencies
 
 - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
@@ -255,6 +265,8 @@ loss.backward()
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
 ## Low-level APIs
@@ -283,6 +295,7 @@ loss.backward()
 | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
 | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
 | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+| Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
 
 ### Distillation Kernels
 
@@ -291,6 +304,7 @@ loss.backward()
 | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
 | JSD | `liger_kernel.transformers.LigerJSD` |
 | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
+| TVD | `liger_kernel.transformers.LigerTVDLoss` |
 
 ### Experimental Kernels
 
@@ -322,7 +336,7 @@ loss.backward()
 
 - For issues, create a Github ticket in this repository
 - For open discussion, join [our discord channel](https://discord.gg/gpumode)
-- For formal collaboration, send an email to
+- For formal collaboration, send an email to yannchen@linkedin.com
 
 ## Cite this work
 
```
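The tables above gain Granite 3.0/3.1 and OLMo2 rows in the patching-API section plus KTO and TVD entries in the loss tables. A minimal usage sketch for the newly listed OLMo2 entry point, following the same pattern as the other `apply_liger_kernel_to_*` helpers (the checkpoint name is illustrative, not part of this diff):

```python
# Sketch of the newly documented OLMo2 patching API; keyword arguments follow the
# pattern of the other apply_liger_kernel_to_* helpers and can be left at defaults.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_olmo2

apply_liger_kernel_to_olmo2()  # patch the HF OLMo2 modeling code before loading the model
model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")  # illustrative checkpoint
```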
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/data/all_benchmark_data.csv

```diff
@@ -505,6 +505,42 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859
 fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
 fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
 fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
+tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+tvd,liger,forward,speed,ms,V,vocab size,4096,0.47814399003982544,0.4774720072746277,0.4790079891681671,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+tvd,liger,forward,speed,ms,V,vocab size,8192,0.906495988368988,0.905951976776123,0.9073920249938965,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+tvd,liger,forward,speed,ms,V,vocab size,16384,1.8787360191345215,1.8778239488601685,1.8797119855880737,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+tvd,liger,forward,speed,ms,V,vocab size,32768,3.5788800716400146,3.5772159099578857,3.58076810836792,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+tvd,liger,forward,speed,ms,V,vocab size,65536,7.008831977844238,7.007718086242676,7.010636806488037,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+tvd,liger,forward,speed,ms,V,vocab size,131072,13.88646411895752,13.88128662109375,13.890560150146484,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+tvd,torch,forward,speed,ms,V,vocab size,4096,1.308608055114746,1.306502342224121,1.3104127645492554,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+tvd,torch,forward,speed,ms,V,vocab size,8192,2.4735519886016846,2.472287893295288,2.4749441146850586,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+tvd,torch,forward,speed,ms,V,vocab size,16384,4.828320026397705,4.826848030090332,4.830643177032471,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+tvd,torch,forward,speed,ms,V,vocab size,32768,9.5206880569458,9.517024040222168,9.525145530700684,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+tvd,torch,forward,speed,ms,V,vocab size,65536,19.01535987854004,19.011123657226562,19.01806640625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+tvd,torch,forward,speed,ms,V,vocab size,131072,38.022865295410156,38.01945877075195,38.02627182006836,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+tvd,liger,full,speed,ms,V,vocab size,4096,2.626512050628662,2.621260643005371,2.646751880645752,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+tvd,liger,full,speed,ms,V,vocab size,8192,4.661711692810059,4.657618999481201,4.662930965423584,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+tvd,liger,full,speed,ms,V,vocab size,16384,9.088272094726562,9.080741882324219,9.092268943786621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+tvd,liger,full,speed,ms,V,vocab size,32768,18.116064071655273,18.112728118896484,18.118234634399414,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+tvd,liger,full,speed,ms,V,vocab size,65536,35.85124969482422,35.849971771240234,35.85252380371094,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+tvd,liger,full,speed,ms,V,vocab size,131072,71.1648941040039,71.1648941040039,71.1648941040039,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+tvd,torch,full,speed,ms,V,vocab size,4096,4.361599922180176,4.360159873962402,4.3639678955078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
+tvd,torch,full,speed,ms,V,vocab size,8192,8.11302375793457,8.11075210571289,8.114463806152344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
+tvd,torch,full,speed,ms,V,vocab size,16384,15.841055870056152,15.837087631225586,15.841856002807617,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
+tvd,torch,full,speed,ms,V,vocab size,32768,31.71219253540039,31.706951141357422,31.715898513793945,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
+tvd,torch,full,speed,ms,V,vocab size,65536,63.17919921875,63.17919921875,63.17919921875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
+tvd,torch,full,speed,ms,V,vocab size,131072,126.0436782836914,126.0436782836914,126.0436782836914,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
 group_norm,liger,forward,speed,ms,C,num_channels,32,0.03481600061058998,0.03379200026392937,0.03993599861860275,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
 group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05119999870657921,0.05222399905323982,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
 group_norm,liger,forward,speed,ms,C,num_channels,128,0.08499199897050858,0.08396799862384796,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
@@ -769,3 +805,4 @@ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,
 distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
 distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
 distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
+
```
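For orientation, the new TVD rows show the Triton kernel ahead of the eager PyTorch reference at every vocabulary size measured: at V = 131072 the forward pass drops from about 38.0 ms (torch) to 13.9 ms (liger), roughly 2.7x faster, and peak memory for the full pass drops from 65536 MB to 57344 MB, about 12.5% less.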
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_kto_loss.py

```diff
@@ -103,8 +103,8 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
         H=H,
         V=V,
         dtype=dtype,
-
-
+        use_bias=bias,
+        use_ref_bias=bias,
         ignore_index=ignore_index,
         beta=beta,
     ).to(device)
@@ -113,8 +113,8 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
         H=H,
         V=V,
         dtype=dtype,
-
-
+        use_bias=bias,
+        use_ref_bias=bias,
         ignore_index=ignore_index,
         beta=beta,
     ).to(device)
@@ -189,7 +189,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
         dtype=dtype,
         beta=beta,
         ignore_index=ignore_index,
-
+        use_bias=bias,
     ).to(device)
     liger_kto_loss = LigerLMHeadKTO(
         H=H,
@@ -197,7 +197,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
         dtype=dtype,
         beta=beta,
         ignore_index=ignore_index,
-
+        use_bias=bias,
     ).to(device)
 
     # Input shape: [B, T, H]
```
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_rope.py

```diff
@@ -1,7 +1,6 @@
 import torch
 import triton
 
-from test.utils import transformers_version_dispatch
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
@@ -14,6 +13,7 @@ from utils import run_benchmarks
 
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import transformers_version_dispatch
 
 device = infer_device()
 
```
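`transformers_version_dispatch` moves out of the test utilities and into `liger_kernel.utils` (the new +62-line module), so the benchmark no longer imports from `test/`. The real implementation lives in that module; as a rough, illustrative sketch only, a helper of this kind picks one of two callables based on the installed transformers version:

```python
# Illustrative sketch only -- consult src/liger_kernel/utils.py for the real helper,
# whose exact signature may differ.
import transformers
from packaging import version


def transformers_version_dispatch(required_version: str, before_fn, after_fn):
    """Return before_fn if the installed transformers is older than required_version, else after_fn."""
    if version.parse(transformers.__version__) < version.parse(required_version):
        return before_fn
    return after_fn
```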
liger_kernel-0.5.4/benchmark/scripts/benchmark_tvd.py

```diff
@@ -0,0 +1,133 @@
+import torch
+import triton
+
+from utils import QUANTILES
+from utils import SingleBenchmarkRunInput
+from utils import SingleBenchmarkRunOutput
+from utils import _test_memory
+from utils import parse_benchmark_script_args
+from utils import run_benchmarks
+
+from liger_kernel.transformers.tvd import LigerTVDLoss
+
+
+class TorchTVDLoss(torch.nn.Module):
+    def __init__(self, reduction="batchmean"):
+        super(TorchTVDLoss, self).__init__()
+        self.reduction = reduction
+
+    def forward(self, p, q):
+        tvd = torch.abs(p - q) / 2.0
+        if self.reduction == "mean":
+            return torch.sum(tvd) / (p.size(0) * p.size(1))
+        elif self.reduction == "sum":
+            return torch.sum(tvd)
+        elif self.reduction == "none":
+            return tvd
+        elif self.reduction == "batchmean":
+            return torch.sum(tvd) / p.size(0)
+        else:
+            raise ValueError("Invalid reduction type.")
+
+
+S, E = 12, 18
+
+
+def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    reduction = "batchmean"
+    V = input.x
+    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
+    torch_tvd = TorchTVDLoss(reduction=reduction)
+    liger_tvd = LigerTVDLoss(reduction=reduction)
+
+    _input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
+    target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
+
+    def fwd():
+        if input.kernel_provider == "liger":
+            return liger_tvd(_input, target)
+        else:
+            return torch_tvd(_input, target)
+
+    if input.kernel_operation_mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
+    elif input.kernel_operation_mode == "backward":
+        y = fwd()
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(retain_graph=True),
+            quantiles=QUANTILES,
+            grad_to_none=[_input],
+            rep=100,
+        )
+    elif input.kernel_operation_mode == "full":
+
+        def full():
+            y = fwd()
+            y.backward(retain_graph=True)
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
+    return SingleBenchmarkRunOutput(
+        y_20=ms_20,
+        y_50=ms_50,
+        y_80=ms_80,
+    )
+
+
+def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    reduction = "batchmean"
+    torch_tvd = TorchTVDLoss(reduction=reduction)
+    liger_tvd = LigerTVDLoss(reduction=reduction)
+
+    V = input.x
+    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
+
+    _input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
+    target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
+
+    def fwd():
+        if input.kernel_provider == "liger":
+            return liger_tvd(_input, target)
+        else:
+            return torch_tvd(_input, target)
+
+    def full():
+        y = fwd()
+        y.backward(retain_graph=True)
+
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+
+    return SingleBenchmarkRunOutput(
+        y_20=mem_20,
+        y_50=mem_50,
+        y_80=mem_80,
+    )
+
+
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()
+    common_args = {
+        "kernel_name": "tvd",
+        "x_name": "V",
+        "x_label": "vocab size",
+        "x_values": [2**i for i in range(12, 18)],
+        "kernel_providers": ["liger", "torch"],
+        "extra_benchmark_configs": [{"B": 8, "T": 2048}],
+        "overwrite": args.overwrite,
+    }
+
+    run_benchmarks(
+        bench_test_fn=bench_memory_tvd,
+        kernel_operation_modes=["full"],
+        metric_name="memory",
+        metric_unit="MB",
+        **common_args,
+    )
+
+    run_benchmarks(
+        bench_test_fn=bench_speed_tvd,
+        kernel_operation_modes=["forward", "full"],
+        metric_name="speed",
+        metric_unit="ms",
+        **common_args,
+    )
```
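The `TorchTVDLoss` reference above makes the definition explicit: total variation distance is half the L1 distance between the two distributions, reduced according to the chosen mode. A minimal standalone use of the new Liger kernel, mirroring the benchmark's call pattern and `batchmean` reduction (the 32x4096 shapes and CUDA device are illustrative assumptions):

```python
# Minimal usage of the new TVD loss, mirroring the benchmark script above.
import torch

from liger_kernel.transformers.tvd import LigerTVDLoss

tvd = LigerTVDLoss(reduction="batchmean")
p = torch.randn(32, 4096, device="cuda", requires_grad=True).softmax(dim=-1)  # e.g. student probabilities
q = torch.randn(32, 4096, device="cuda").softmax(dim=-1)                      # e.g. teacher probabilities
loss = tvd(p, q)   # 0.5 * sum(|p - q|), divided by the batch size
loss.backward()
```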
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/dev/modal/tests.py

```diff
@@ -14,7 +14,7 @@ app = modal.App("liger_tests", image=image)
 repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
 
 
-@app.function(gpu="A10G", mounts=[repo], timeout=60 *
+@app.function(gpu="A10G", mounts=[repo], timeout=60 * 20)
 def liger_tests():
     import subprocess
 
```
docs/images/post-training.png: Binary file (no text diff shown).
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/lightning/training.py

```diff
@@ -158,7 +158,7 @@ class DataModule(pl.LightningDataModule):
         for i in range(len(example["question"])):
             choices = ""
             for j in range(len(example["choices"][i])):
-                choices += f"{j+1}. {example['choices'][i][j]}; "
+                choices += f"{j + 1}. {example['choices'][i][j]}; "
             s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
             s += f"{QUESTION}{example['question'][i]} "
             s += f"{CHOICES}{choices} "
```
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/callback.py

```diff
@@ -352,9 +352,9 @@ class EfficiencyCallback(transformers.TrainerCallback):
         else:
             return world_size
 
-        assert (
-
-        )
+        assert world_size != 0, (
+            "WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1."
+        )
 
         # TODO: add deepspeed support
         return world_size
```
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/pyproject.toml

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel"
-version = "0.5.3"
+version = "0.5.4"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
```
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/setup.py

```diff
@@ -21,6 +21,11 @@ def get_default_dependencies():
             "torch>=2.6.0.dev",
             "triton>=3.0.0",
         ]
+    elif platform == "xpu":
+        return [
+            "torch>=2.6.0",
+            "pytorch-triton-xpu>=3.2.0",
+        ]
 
 
 def get_optional_dependencies():
@@ -43,8 +48,7 @@ def get_optional_dependencies():
     }
 
 
-
-def get_platform() -> Literal["cuda", "rocm", "cpu"]:
+def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu"]:
     """
     Detect whether the system has NVIDIA or AMD GPU without torch dependency.
     """
@@ -60,8 +64,13 @@ def get_platform() -> Literal["cuda", "rocm", "cpu"]:
             print("ROCm GPU detected")
             return "rocm"
         except (subprocess.SubprocessError, FileNotFoundError):
-
-
+            try:
+                subprocess.run(["xpu-smi"], check=True)
+                print("Intel GPU detected")
+                return "xpu"
+            except (subprocess.SubprocessError, FileNotFoundError):
+                print("No GPU detected")
+                return "cpu"
 
 
 setup(
```
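With `get_platform()` now recognizing Intel GPUs via `xpu-smi`, an editable install on an XPU host resolves to `torch>=2.6.0` plus `pytorch-triton-xpu>=3.2.0`; the wheel index still has to be passed explicitly, exactly as the new intel-ci workflow does:

```bash
# Editable install on an Intel XPU machine, matching the intel-ci.yml "Setup Dependencies" step above.
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/test/xpu
```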
{liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/__init__.py

```diff
@@ -1,5 +1,6 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss  # noqa: F401
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss  # noqa: F401
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss  # noqa: F401
```