liger-kernel 0.5.9__tar.gz → 0.5.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/docs.yml +1 -1
- liger_kernel-0.5.10/.idea/workspace.xml +79 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/PKG-INFO +34 -20
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/README.md +33 -19
- liger_kernel-0.5.10/benchmark/README.md +48 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/benchmarks_visualizer.py +35 -9
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/data/all_benchmark_data.csv +72 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_dyt.py +37 -34
- liger_kernel-0.5.10/benchmark/scripts/benchmark_sparsemax.py +172 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/dev/modal/tests.py +2 -2
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/dev/modal/tests_bwd.py +2 -2
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/Low-Level-APIs.md +9 -1
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/training_multimodal.py +1 -1
- liger_kernel-0.5.10/examples/medusa/requirements.txt +3 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/scripts/llama3_8b_medusa.sh +2 -5
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/train.py +37 -39
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/pyproject.toml +1 -1
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/setup.py +23 -3
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/dpo_loss.py +1 -1
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel-0.5.10/src/liger_kernel/ops/dyt.py +159 -0
- liger_kernel-0.5.10/src/liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel-0.5.10/src/liger_kernel/ops/sparsemax.py +167 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/__init__.py +5 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/dyt.py +5 -3
- liger_kernel-0.5.10/src/liger_kernel/transformers/fsdp.py +55 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/functional.py +8 -0
- liger_kernel-0.5.10/src/liger_kernel/transformers/grpo_loss.py +98 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/gemma.py +0 -8
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/gemma2.py +0 -6
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/gemma3.py +0 -8
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/glm4.py +0 -6
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/llama.py +56 -11
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/llava.py +0 -8
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/mistral.py +0 -6
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/mixtral.py +0 -8
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/mllama.py +0 -7
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/olmo2.py +0 -6
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/paligemma.py +0 -8
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/phi3.py +0 -8
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/qwen2.py +0 -8
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -6
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -6
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/qwen3.py +0 -6
- liger_kernel-0.5.10/src/liger_kernel/transformers/model/qwen3_moe.py +128 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/monkey_patch.py +122 -13
- liger_kernel-0.5.10/src/liger_kernel/transformers/sparsemax.py +16 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/swiglu.py +21 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel.egg-info/PKG-INFO +34 -20
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel.egg-info/SOURCES.txt +10 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_dpo_loss.py +2 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/bf16/test_mini_models.py +58 -1
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/bf16/test_mini_models_multimodal.py +0 -1
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/bf16/test_mini_models_with_logits.py +58 -1
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/fp32/test_mini_models.py +55 -1
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/fp32/test_mini_models_multimodal.py +0 -1
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/fp32/test_mini_models_with_logits.py +55 -1
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_dyt.py +40 -20
- liger_kernel-0.5.10/test/transformers/test_grpo_loss.py +190 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_monkey_patch.py +40 -0
- liger_kernel-0.5.10/test/transformers/test_sparsemax.py +111 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/utils.py +12 -0
- liger_kernel-0.5.9/benchmark/README.md +0 -30
- liger_kernel-0.5.9/examples/medusa/requirements.txt +0 -3
- liger_kernel-0.5.9/src/liger_kernel/ops/dyt.py +0 -225
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/pull_request_template.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.gitignore +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/LICENSE +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/Makefile +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/NOTICE +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/utils.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/dev/fmt-requirements.txt +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/Examples.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/Getting-Started.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/High-Level-APIs.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/acknowledgement.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/contributing.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/banner.GIF +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/compose.gif +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/e2e-memory.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/e2e-tps.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/logo-banner.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/patch.gif +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/post-training.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/index.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/license.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/README.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/callback.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/training.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/lightning/README.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/lightning/requirements.txt +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/lightning/training.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/README.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/callback.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/mkdocs.yml +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/setup.cfg +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/gema3_rms.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel.egg-info/requires.txt +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel.egg-info/top_level.txt +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/conftest.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_embedding.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_geglu.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_jsd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_rope.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_transformers.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_tvd.py +0 -0
- {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/triton/test_triton_monkey_patch.py +0 -0
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
<?xml version="1.0" encoding="UTF-8"?>
|
|
2
|
+
<project version="4">
|
|
3
|
+
<component name="AutoImportSettings">
|
|
4
|
+
<option name="autoReloadType" value="SELECTIVE" />
|
|
5
|
+
</component>
|
|
6
|
+
<component name="ChangeListManager">
|
|
7
|
+
<list default="true" id="d7400753-faa8-4997-a53e-65fd3a6e6146" name="Changes" comment="Reference Unsloth in header" />
|
|
8
|
+
<option name="SHOW_DIALOG" value="false" />
|
|
9
|
+
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
10
|
+
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
|
11
|
+
<option name="LAST_RESOLUTION" value="IGNORE" />
|
|
12
|
+
</component>
|
|
13
|
+
<component name="Git.Settings">
|
|
14
|
+
<option name="RECENT_BRANCH_BY_REPOSITORY">
|
|
15
|
+
<map>
|
|
16
|
+
<entry key="$PROJECT_DIR$" value="main" />
|
|
17
|
+
</map>
|
|
18
|
+
</option>
|
|
19
|
+
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
|
20
|
+
</component>
|
|
21
|
+
<component name="GitHubPullRequestSearchHistory"><![CDATA[{
|
|
22
|
+
"lastFilter": {
|
|
23
|
+
"state": "OPEN",
|
|
24
|
+
"assignee": "momochen"
|
|
25
|
+
}
|
|
26
|
+
}]]></component>
|
|
27
|
+
<component name="GithubPullRequestsUISettings"><![CDATA[{
|
|
28
|
+
"selectedUrlAndAccountId": {
|
|
29
|
+
"url": "https://github.com/momochen/Liger-Kernel",
|
|
30
|
+
"accountId": "639f3e12-86db-4b12-a409-51cc017415fb"
|
|
31
|
+
}
|
|
32
|
+
}]]></component>
|
|
33
|
+
<component name="ProjectColorInfo"><![CDATA[{
|
|
34
|
+
"associatedIndex": 5
|
|
35
|
+
}]]></component>
|
|
36
|
+
<component name="ProjectId" id="2lfyDxCjSnvFrbllYmf9VBSCcMx" />
|
|
37
|
+
<component name="ProjectViewState">
|
|
38
|
+
<option name="hideEmptyMiddlePackages" value="true" />
|
|
39
|
+
<option name="showLibraryContents" value="true" />
|
|
40
|
+
</component>
|
|
41
|
+
<component name="PropertiesComponent"><![CDATA[{
|
|
42
|
+
"keyToString": {
|
|
43
|
+
"RunOnceActivity.ShowReadmeOnStart": "true",
|
|
44
|
+
"git-widget-placeholder": "ref__unsloth",
|
|
45
|
+
"last_opened_file_path": "/Users/ychen/workspace/github/Liger-Kernel"
|
|
46
|
+
}
|
|
47
|
+
}]]></component>
|
|
48
|
+
<component name="SharedIndexes">
|
|
49
|
+
<attachedChunks>
|
|
50
|
+
<set>
|
|
51
|
+
<option value="bundled-python-sdk-975db3bf15a3-31b6be0877a2-com.jetbrains.pycharm.community.sharedIndexes.bundled-PC-241.18034.82" />
|
|
52
|
+
</set>
|
|
53
|
+
</attachedChunks>
|
|
54
|
+
</component>
|
|
55
|
+
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
|
|
56
|
+
<component name="TaskManager">
|
|
57
|
+
<task active="true" id="Default" summary="Default task">
|
|
58
|
+
<changelist id="d7400753-faa8-4997-a53e-65fd3a6e6146" name="Changes" comment="" />
|
|
59
|
+
<created>1725585310555</created>
|
|
60
|
+
<option name="number" value="Default" />
|
|
61
|
+
<option name="presentableId" value="Default" />
|
|
62
|
+
<updated>1725585310555</updated>
|
|
63
|
+
</task>
|
|
64
|
+
<task id="LOCAL-00001" summary="Reference Unsloth in header">
|
|
65
|
+
<option name="closed" value="true" />
|
|
66
|
+
<created>1725585434299</created>
|
|
67
|
+
<option name="number" value="00001" />
|
|
68
|
+
<option name="presentableId" value="LOCAL-00001" />
|
|
69
|
+
<option name="project" value="LOCAL" />
|
|
70
|
+
<updated>1725585434299</updated>
|
|
71
|
+
</task>
|
|
72
|
+
<option name="localTasksCounter" value="2" />
|
|
73
|
+
<servers />
|
|
74
|
+
</component>
|
|
75
|
+
<component name="VcsManagerConfiguration">
|
|
76
|
+
<MESSAGE value="Reference Unsloth in header" />
|
|
77
|
+
<option name="LAST_COMMIT_MESSAGE" value="Reference Unsloth in header" />
|
|
78
|
+
</component>
|
|
79
|
+
</project>
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: liger_kernel
|
|
3
|
-
Version: 0.5.9
|
|
3
|
+
Version: 0.5.10
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -59,7 +59,6 @@ Dynamic: requires-dist
|
|
|
59
59
|
<th style="padding: 10px;" colspan="2">Stable</th>
|
|
60
60
|
<th style="padding: 10px;" colspan="2">Nightly</th>
|
|
61
61
|
<th style="padding: 10px;">Discord</th>
|
|
62
|
-
<th style="padding: 10px;">Build</th>
|
|
63
62
|
</tr>
|
|
64
63
|
<tr>
|
|
65
64
|
<td style="padding: 10px;">
|
|
@@ -87,23 +86,6 @@ Dynamic: requires-dist
|
|
|
87
86
|
<img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
|
|
88
87
|
</a>
|
|
89
88
|
</td>
|
|
90
|
-
<td style="padding: 10px;">
|
|
91
|
-
<div style="display: block;">
|
|
92
|
-
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
|
|
93
|
-
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
94
|
-
</a>
|
|
95
|
-
</div>
|
|
96
|
-
<div style="display: block;">
|
|
97
|
-
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
|
|
98
|
-
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
99
|
-
</a>
|
|
100
|
-
</div>
|
|
101
|
-
<div style="display: block;">
|
|
102
|
-
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
|
|
103
|
-
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
104
|
-
</a>
|
|
105
|
-
</div>
|
|
106
|
-
</td>
|
|
107
89
|
</tr>
|
|
108
90
|
</table>
|
|
109
91
|
|
|
@@ -321,6 +303,7 @@ loss.backward()
|
|
|
321
303
|
| Qwen2-VL & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
322
304
|
| Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
323
305
|
| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
306
|
+
| Qwen3 MoE | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
324
307
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
325
308
|
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
|
|
326
309
|
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
@@ -342,7 +325,8 @@ loss.backward()
|
|
|
342
325
|
| SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
|
|
343
326
|
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
|
|
344
327
|
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
|
|
345
|
-
| Fused Linear CrossEntropy
|
|
328
|
+
| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss` |
|
|
329
|
+
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
|
|
346
330
|
|
|
347
331
|
|
|
348
332
|
### Alignment Kernels
|
|
@@ -390,6 +374,36 @@ loss.backward()
|
|
|
390
374
|
- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
|
|
391
375
|
- [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
|
|
392
376
|
|
|
377
|
+
|
|
378
|
+
## CI status
|
|
379
|
+
|
|
380
|
+
<table style="width: 100%; text-align: center; border-collapse: collapse;">
|
|
381
|
+
<tr>
|
|
382
|
+
<th style="padding: 10px;">Build</th>
|
|
383
|
+
</tr>
|
|
384
|
+
<tr>
|
|
385
|
+
<td style="padding: 10px;">
|
|
386
|
+
<div style="display: block;">
|
|
387
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
|
|
388
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
389
|
+
</a>
|
|
390
|
+
</div>
|
|
391
|
+
<div style="display: block;">
|
|
392
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
|
|
393
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
394
|
+
</a>
|
|
395
|
+
</div>
|
|
396
|
+
<div style="display: block;">
|
|
397
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
|
|
398
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
399
|
+
</a>
|
|
400
|
+
</div>
|
|
401
|
+
</td>
|
|
402
|
+
</tr>
|
|
403
|
+
</table>
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
|
|
393
407
|
## Contact
|
|
394
408
|
|
|
395
409
|
- For issues, create a Github ticket in this repository
|
|
@@ -8,7 +8,6 @@
|
|
|
8
8
|
<th style="padding: 10px;" colspan="2">Stable</th>
|
|
9
9
|
<th style="padding: 10px;" colspan="2">Nightly</th>
|
|
10
10
|
<th style="padding: 10px;">Discord</th>
|
|
11
|
-
<th style="padding: 10px;">Build</th>
|
|
12
11
|
</tr>
|
|
13
12
|
<tr>
|
|
14
13
|
<td style="padding: 10px;">
|
|
@@ -36,23 +35,6 @@
|
|
|
36
35
|
<img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
|
|
37
36
|
</a>
|
|
38
37
|
</td>
|
|
39
|
-
<td style="padding: 10px;">
|
|
40
|
-
<div style="display: block;">
|
|
41
|
-
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
|
|
42
|
-
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
43
|
-
</a>
|
|
44
|
-
</div>
|
|
45
|
-
<div style="display: block;">
|
|
46
|
-
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
|
|
47
|
-
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
48
|
-
</a>
|
|
49
|
-
</div>
|
|
50
|
-
<div style="display: block;">
|
|
51
|
-
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
|
|
52
|
-
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
53
|
-
</a>
|
|
54
|
-
</div>
|
|
55
|
-
</td>
|
|
56
38
|
</tr>
|
|
57
39
|
</table>
|
|
58
40
|
|
|
@@ -270,6 +252,7 @@ loss.backward()
|
|
|
270
252
|
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
271
253
|
| Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
272
254
|
| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
255
|
+
| Qwen3 MoE | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
273
256
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
274
257
|
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
|
|
275
258
|
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
@@ -291,7 +274,8 @@ loss.backward()
|
|
|
291
274
|
| SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
|
|
292
275
|
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
|
|
293
276
|
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
|
|
294
|
-
| Fused Linear CrossEntropy
|
|
277
|
+
| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
|
|
278
|
+
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
|
|
295
279
|
|
|
296
280
|
|
|
297
281
|
### Alignment Kernels
|
|
@@ -339,6 +323,36 @@ loss.backward()
|
|
|
339
323
|
- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
|
|
340
324
|
- [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
|
|
341
325
|
|
|
326
|
+
|
|
327
|
+
## CI status
|
|
328
|
+
|
|
329
|
+
<table style="width: 100%; text-align: center; border-collapse: collapse;">
|
|
330
|
+
<tr>
|
|
331
|
+
<th style="padding: 10px;">Build</th>
|
|
332
|
+
</tr>
|
|
333
|
+
<tr>
|
|
334
|
+
<td style="padding: 10px;">
|
|
335
|
+
<div style="display: block;">
|
|
336
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
|
|
337
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
338
|
+
</a>
|
|
339
|
+
</div>
|
|
340
|
+
<div style="display: block;">
|
|
341
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
|
|
342
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
343
|
+
</a>
|
|
344
|
+
</div>
|
|
345
|
+
<div style="display: block;">
|
|
346
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
|
|
347
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
348
|
+
</a>
|
|
349
|
+
</div>
|
|
350
|
+
</td>
|
|
351
|
+
</tr>
|
|
352
|
+
</table>
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
|
|
342
356
|
## Contact
|
|
343
357
|
|
|
344
358
|
- For issues, create a Github ticket in this repository
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
## Benchmarking Liger Kernels
|
|
2
|
+
|
|
3
|
+
Follow these steps to benchmark and visualize kernel performance:
|
|
4
|
+
|
|
5
|
+
1. Create a benchmark script
|
|
6
|
+
- Add your script under `benchmark/scripts/`
|
|
7
|
+
- Name it according to the kernel (e.g., `benchmark_<kernel_name>.py`)
|
|
8
|
+
|
|
9
|
+
2. Run the benchmark
|
|
10
|
+
- Results will be saved to `benchmark/data/all_benchmark_data.csv`
|
|
11
|
+
|
|
12
|
+
Example: Benchmarking KTO Loss
|
|
13
|
+
```bash
|
|
14
|
+
cd benchmark
|
|
15
|
+
python scripts/benchmark_kto_loss.py
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
3. Visualize results
|
|
19
|
+
- Use the visualization script with optional modes:
|
|
20
|
+
|
|
21
|
+
* To target specific mode(s), pass one or more values to `--kernel-operation-mode`.
|
|
22
|
+
* If you omit `--kernel-operation-mode`, the script will:
|
|
23
|
+
- For `speed` metrics: generate plots for all available modes (forward/backward/full).
|
|
24
|
+
- For `memory` metrics: generate only the `full` plot.
|
|
25
|
+
|
|
26
|
+
Examples:
|
|
27
|
+
1. Specific modes (speed):
|
|
28
|
+
```bash
|
|
29
|
+
python benchmarks_visualizer.py \
|
|
30
|
+
--kernel-name kto_loss \
|
|
31
|
+
--metric-name speed \
|
|
32
|
+
--kernel-operation-mode forward backward
|
|
33
|
+
```
|
|
34
|
+
2. All modes (speed):
|
|
35
|
+
```bash
|
|
36
|
+
python benchmarks_visualizer.py \
|
|
37
|
+
--kernel-name kto_loss \
|
|
38
|
+
--metric-name speed
|
|
39
|
+
```
|
|
40
|
+
3. Memory (always full):
|
|
41
|
+
```bash
|
|
42
|
+
python benchmarks_visualizer.py \
|
|
43
|
+
--kernel-name kto_loss \
|
|
44
|
+
--metric-name memory
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
4. View results
|
|
48
|
+
- Generated plots will be saved in `benchmark/visualizations/`
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import os
|
|
3
|
+
import sys
|
|
3
4
|
|
|
4
5
|
from argparse import ArgumentParser
|
|
5
6
|
from dataclasses import dataclass
|
|
@@ -50,8 +51,9 @@ def parse_args() -> VisualizationsConfig:
|
|
|
50
51
|
parser.add_argument(
|
|
51
52
|
"--kernel-operation-mode",
|
|
52
53
|
type=str,
|
|
53
|
-
|
|
54
|
-
|
|
54
|
+
nargs="*",
|
|
55
|
+
default=None,
|
|
56
|
+
help="Kernel operation modes to visualize (forward/backward/full). If not provided, generate for all available modes.",
|
|
55
57
|
)
|
|
56
58
|
parser.add_argument("--display", action="store_true", help="Display the visualization")
|
|
57
59
|
parser.add_argument(
|
|
@@ -61,8 +63,7 @@ def parse_args() -> VisualizationsConfig:
|
|
|
61
63
|
)
|
|
62
64
|
|
|
63
65
|
args = parser.parse_args()
|
|
64
|
-
|
|
65
|
-
return VisualizationsConfig(**dict(args._get_kwargs()))
|
|
66
|
+
return args
|
|
66
67
|
|
|
67
68
|
|
|
68
69
|
def load_data(config: VisualizationsConfig) -> pd.DataFrame:
|
|
@@ -123,7 +124,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
|
|
|
123
124
|
lines = ax.get_lines()
|
|
124
125
|
colors = [line.get_color() for line in lines]
|
|
125
126
|
|
|
126
|
-
for (_, group_data), color in zip(df.groupby("kernel_provider"), colors
|
|
127
|
+
for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
|
|
127
128
|
# for i, row in group_data.iterrows():
|
|
128
129
|
y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
|
|
129
130
|
y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
|
|
@@ -142,7 +143,10 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
|
|
|
142
143
|
plt.ylabel(ylabel)
|
|
143
144
|
plt.tight_layout()
|
|
144
145
|
|
|
145
|
-
out_path = os.path.join(
|
|
146
|
+
out_path = os.path.join(
|
|
147
|
+
VISUALIZATIONS_PATH,
|
|
148
|
+
f"{config.kernel_name}_{config.metric_name}_{config.kernel_operation_mode}.png",
|
|
149
|
+
)
|
|
146
150
|
|
|
147
151
|
if config.display:
|
|
148
152
|
plt.show()
|
|
@@ -155,9 +159,31 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
|
|
|
155
159
|
|
|
156
160
|
|
|
157
161
|
def main():
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
162
|
+
args = parse_args()
|
|
163
|
+
all_df = pd.read_csv(DATA_PATH)
|
|
164
|
+
all_df["extra_benchmark_config"] = all_df["extra_benchmark_config_str"].apply(json.loads)
|
|
165
|
+
|
|
166
|
+
if args.metric_name == "memory":
|
|
167
|
+
modes = ["full"]
|
|
168
|
+
elif args.kernel_operation_mode:
|
|
169
|
+
modes = args.kernel_operation_mode
|
|
170
|
+
else:
|
|
171
|
+
filtered = all_df[(all_df["kernel_name"] == args.kernel_name) & (all_df["metric_name"] == args.metric_name)]
|
|
172
|
+
modes = filtered["kernel_operation_mode"].unique().tolist()
|
|
173
|
+
if not modes:
|
|
174
|
+
print(f"No data found for kernel '{args.kernel_name}' and metric '{args.metric_name}'.", file=sys.stderr)
|
|
175
|
+
sys.exit(1)
|
|
176
|
+
|
|
177
|
+
for mode in modes:
|
|
178
|
+
config = VisualizationsConfig(
|
|
179
|
+
kernel_name=args.kernel_name,
|
|
180
|
+
metric_name=args.metric_name,
|
|
181
|
+
kernel_operation_mode=mode,
|
|
182
|
+
display=args.display,
|
|
183
|
+
overwrite=args.overwrite,
|
|
184
|
+
)
|
|
185
|
+
df = load_data(config)
|
|
186
|
+
plot_data(df, config)
|
|
161
187
|
|
|
162
188
|
|
|
163
189
|
if __name__ == "__main__":
|
|
@@ -805,3 +805,75 @@ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.265
|
|
|
805
805
|
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
806
806
|
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
807
807
|
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
808
|
+
sparsemax,liger,forward,speed,ms,V,feature size,1024,0.41471999883651733,0.4126720130443573,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
|
809
|
+
sparsemax,liger,forward,speed,ms,V,feature size,2048,0.7608320116996765,0.7598080039024353,0.7628800272941589,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
|
810
|
+
sparsemax,liger,forward,speed,ms,V,feature size,4096,1.4561280012130737,1.4540799856185913,1.4581760168075562,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
|
811
|
+
sparsemax,liger,forward,speed,ms,V,feature size,8192,5.288959980010986,5.2848639488220215,5.29986572265625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
|
812
|
+
sparsemax,liger,forward,speed,ms,V,feature size,16384,10.734624862670898,10.729472160339355,11.096882820129395,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
|
813
|
+
sparsemax,liger,forward,speed,ms,V,feature size,32768,21.729312896728516,21.7128963470459,22.20728302001953,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
|
814
|
+
sparsemax,torch,forward,speed,ms,V,feature size,1024,0.42291200160980225,0.42188799381256104,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
|
815
|
+
sparsemax,torch,forward,speed,ms,V,feature size,2048,0.7782400250434875,0.7772160172462463,0.779263973236084,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
|
816
|
+
sparsemax,torch,forward,speed,ms,V,feature size,4096,1.4940160512924194,1.491968035697937,1.4960639476776123,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
|
817
|
+
sparsemax,torch,forward,speed,ms,V,feature size,8192,5.359615802764893,5.356544017791748,5.366579055786133,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
|
818
|
+
sparsemax,torch,forward,speed,ms,V,feature size,16384,10.883584022521973,10.874879837036133,11.224268913269043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
|
819
|
+
sparsemax,torch,forward,speed,ms,V,feature size,32768,22.19878387451172,22.018457412719727,22.48888397216797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
|
820
|
+
sparsemax,liger,full,speed,ms,V,feature size,1024,0.4558719992637634,0.45558398962020874,0.45772799849510193,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
|
821
|
+
sparsemax,liger,full,speed,ms,V,feature size,2048,0.8488960266113281,0.8478720188140869,0.8509439826011658,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
|
822
|
+
sparsemax,liger,full,speed,ms,V,feature size,4096,1.6476160287857056,1.6465920209884644,1.6499264240264893,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
|
823
|
+
sparsemax,liger,full,speed,ms,V,feature size,8192,5.664768218994141,5.660672187805176,5.681356906890869,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
|
824
|
+
sparsemax,liger,full,speed,ms,V,feature size,16384,11.486207962036133,11.478015899658203,11.874713897705078,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
|
825
|
+
sparsemax,liger,full,speed,ms,V,feature size,32768,23.457279205322266,23.289682388305664,23.76642608642578,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
|
826
|
+
sparsemax,torch,full,speed,ms,V,feature size,1024,0.6021119952201843,0.6010879874229431,0.6041600108146667,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
|
827
|
+
sparsemax,torch,full,speed,ms,V,feature size,2048,1.1212799549102783,1.119264006614685,1.1223039627075195,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
|
828
|
+
sparsemax,torch,full,speed,ms,V,feature size,4096,2.1637120246887207,2.1616640090942383,2.165760040283203,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
|
829
|
+
sparsemax,torch,full,speed,ms,V,feature size,8192,6.693888187408447,6.68723201751709,6.705561637878418,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
|
830
|
+
sparsemax,torch,full,speed,ms,V,feature size,16384,13.523456573486328,13.518848419189453,13.878681182861328,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
|
831
|
+
sparsemax,torch,full,speed,ms,V,feature size,32768,27.604991912841797,27.295129776000977,27.77518081665039,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
|
832
|
+
sparsemax,liger,backward,speed,ms,V,feature size,1024,0.04403200000524521,0.043007999658584595,0.05222399905323982,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
|
833
|
+
sparsemax,liger,backward,speed,ms,V,feature size,2048,0.08806400001049042,0.08713600039482117,0.08806400001049042,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
|
834
|
+
sparsemax,liger,backward,speed,ms,V,feature size,4096,0.1884160041809082,0.1884160041809082,0.18943999707698822,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
|
835
|
+
sparsemax,liger,backward,speed,ms,V,feature size,8192,0.374783992767334,0.37376001477241516,0.37486720085144043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
|
836
|
+
sparsemax,liger,backward,speed,ms,V,feature size,16384,0.7516160011291504,0.7505919933319092,0.7516160011291504,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
|
837
|
+
sparsemax,liger,backward,speed,ms,V,feature size,32768,1.5738879442214966,1.572864055633545,1.575935959815979,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
|
838
|
+
sparsemax,torch,backward,speed,ms,V,feature size,1024,0.1812479943037033,0.1802240014076233,0.18227200210094452,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
839
|
+
sparsemax,torch,backward,speed,ms,V,feature size,2048,0.34406399726867676,0.34406399726867676,0.34508800506591797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
840
|
+
sparsemax,torch,backward,speed,ms,V,feature size,4096,0.6717439889907837,0.6707199811935425,0.6727679967880249,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
841
|
+
sparsemax,torch,backward,speed,ms,V,feature size,8192,1.3250559568405151,1.3241215944290161,1.3260799646377563,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
842
|
+
sparsemax,torch,backward,speed,ms,V,feature size,16384,2.629631996154785,2.628607988357544,2.6306560039520264,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
843
|
+
sparsemax,torch,backward,speed,ms,V,feature size,32768,5.236735820770264,5.235712051391602,5.239808082580566,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
844
|
+
sparsemax,liger,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
845
|
+
sparsemax,liger,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
846
|
+
sparsemax,liger,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
847
|
+
sparsemax,liger,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
848
|
+
sparsemax,liger,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
849
|
+
sparsemax,liger,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
|
850
|
+
sparsemax,torch,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
|
851
|
+
sparsemax,torch,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
|
852
|
+
sparsemax,torch,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
|
853
|
+
sparsemax,torch,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
|
854
|
+
sparsemax,torch,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
|
855
|
+
sparsemax,torch,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
|
856
|
+
sparsemax,liger,forward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
857
|
+
sparsemax,liger,forward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
858
|
+
sparsemax,liger,forward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
859
|
+
sparsemax,liger,forward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
860
|
+
sparsemax,liger,forward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
861
|
+
sparsemax,liger,forward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
862
|
+
sparsemax,torch,forward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
863
|
+
sparsemax,torch,forward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
864
|
+
sparsemax,torch,forward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
865
|
+
sparsemax,torch,forward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
866
|
+
sparsemax,torch,forward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
867
|
+
sparsemax,torch,forward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
|
868
|
+
sparsemax,liger,backward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
|
869
|
+
sparsemax,liger,backward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
|
870
|
+
sparsemax,liger,backward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
|
871
|
+
sparsemax,liger,backward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
|
872
|
+
sparsemax,liger,backward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
|
873
|
+
sparsemax,liger,backward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
|
874
|
+
sparsemax,torch,backward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
|
875
|
+
sparsemax,torch,backward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
|
876
|
+
sparsemax,torch,backward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
|
877
|
+
sparsemax,torch,backward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
|
878
|
+
sparsemax,torch,backward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
|
879
|
+
sparsemax,torch,backward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
|
@@ -22,17 +22,18 @@ def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
|
|
|
22
22
|
from test.transformers.test_dyt import LigerDyT
|
|
23
23
|
from test.transformers.test_dyt import TorchDyT
|
|
24
24
|
|
|
25
|
-
|
|
25
|
+
hidden_size = input.x
|
|
26
26
|
provider = input.kernel_provider
|
|
27
27
|
mode = input.kernel_operation_mode
|
|
28
28
|
extra_benchmark_config = input.extra_benchmark_config
|
|
29
|
-
|
|
29
|
+
BT = extra_benchmark_config["BT"]
|
|
30
|
+
beta = extra_benchmark_config["beta"]
|
|
30
31
|
dtype = extra_benchmark_config["dtype"]
|
|
31
32
|
|
|
32
33
|
x_shape = (BT, hidden_size)
|
|
33
|
-
torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
|
|
34
|
-
torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
|
|
35
|
-
triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)
|
|
34
|
+
torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta).to(device)
|
|
35
|
+
torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size, beta=beta).to(device))
|
|
36
|
+
triton_dyt = LigerDyT(hidden_size=hidden_size, beta=beta).to(device)
|
|
36
37
|
|
|
37
38
|
x = torch.randn(x_shape, dtype=dtype, device=device)
|
|
38
39
|
dy = torch.randn_like(x)
|
|
@@ -75,16 +76,17 @@ def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
|
|
|
75
76
|
from test.transformers.test_dyt import LigerDyT
|
|
76
77
|
from test.transformers.test_dyt import TorchDyT
|
|
77
78
|
|
|
78
|
-
|
|
79
|
+
hidden_size = input.x
|
|
79
80
|
provider = input.kernel_provider
|
|
80
81
|
extra_benchmark_config = input.extra_benchmark_config
|
|
81
|
-
|
|
82
|
+
BT = extra_benchmark_config["BT"]
|
|
83
|
+
beta = extra_benchmark_config["beta"]
|
|
82
84
|
dtype = extra_benchmark_config["dtype"]
|
|
83
85
|
|
|
84
86
|
x_shape = (BT, hidden_size)
|
|
85
|
-
torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
|
|
86
|
-
torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
|
|
87
|
-
triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)
|
|
87
|
+
torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta).to(device)
|
|
88
|
+
torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size, beta=beta).to(device))
|
|
89
|
+
triton_dyt = LigerDyT(hidden_size=hidden_size, beta=beta).to(device)
|
|
88
90
|
|
|
89
91
|
x = torch.randn(x_shape, dtype=dtype, device=device)
|
|
90
92
|
dy = torch.randn_like(x)
|
|
@@ -113,27 +115,28 @@ def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
|
|
|
113
115
|
if __name__ == "__main__":
|
|
114
116
|
args = parse_benchmark_script_args()
|
|
115
117
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
118
|
+
for beta in [False, True]:
|
|
119
|
+
common_configs = {
|
|
120
|
+
"kernel_name": f"dyt_beta={beta}",
|
|
121
|
+
"x_name": "hidden_size",
|
|
122
|
+
"x_label": "hidden_size",
|
|
123
|
+
"x_values": [1024 * i for i in range(1, 17)],
|
|
124
|
+
"kernel_providers": ["liger", "torch", "torch_compile"],
|
|
125
|
+
"extra_benchmark_configs": [{"BT": 4096, "dtype": torch.bfloat16, "beta": beta}],
|
|
126
|
+
"overwrite": args.overwrite,
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
run_benchmarks(
|
|
130
|
+
bench_test_fn=bench_speed_dyt,
|
|
131
|
+
kernel_operation_modes=["forward", "backward", "full"],
|
|
132
|
+
metric_name="speed",
|
|
133
|
+
metric_unit="ms",
|
|
134
|
+
**common_configs,
|
|
135
|
+
)
|
|
136
|
+
run_benchmarks(
|
|
137
|
+
bench_test_fn=bench_memory_dyt,
|
|
138
|
+
kernel_operation_modes=["full"],
|
|
139
|
+
metric_name="memory",
|
|
140
|
+
metric_unit="MB",
|
|
141
|
+
**common_configs,
|
|
142
|
+
)
|