liger-kernel-nightly 0.5.5.dev20250320214749__tar.gz → 0.5.5.dev20250324181221__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/PKG-INFO +4 -4
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/README.md +3 -3
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/benchmarks_visualizer.py +2 -2
- liger_kernel_nightly-0.5.5.dev20250324181221/benchmark/scripts/benchmark_dyt.py +139 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/pyproject.toml +1 -1
- liger_kernel_nightly-0.5.5.dev20250324181221/src/liger_kernel/ops/dyt.py +225 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/__init__.py +1 -0
- liger_kernel_nightly-0.5.5.dev20250324181221/src/liger_kernel/transformers/dyt.py +20 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/functional.py +5 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/PKG-INFO +4 -4
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
- liger_kernel_nightly-0.5.5.dev20250324181221/test/transformers/test_dyt.py +136 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/Makefile +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/setup.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/paligemma.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: liger_kernel_nightly
|
|
3
|
-
Version: 0.5.5.dev20250320214749
|
|
3
|
+
Version: 0.5.5.dev20250324181221
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -176,7 +176,7 @@ y = orpo_loss(lm_head.weight, x, target)
|
|
|
176
176
|
- **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
|
|
177
177
|
- **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
|
|
178
178
|
- **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
|
|
179
|
-
- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
|
|
179
|
+
- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)
|
|
180
180
|
|
|
181
181
|
## Installation
|
|
182
182
|
|
|
@@ -386,8 +386,8 @@ loss.backward()
|
|
|
386
386
|
## Contact
|
|
387
387
|
|
|
388
388
|
- For issues, create a Github ticket in this repository
|
|
389
|
-
- For open discussion, join [our discord channel](https://discord.
|
|
390
|
-
- For formal collaboration, send an email to yannchen@linkedin.com
|
|
389
|
+
- For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
|
|
390
|
+
- For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
|
|
391
391
|
|
|
392
392
|
## Cite this work
|
|
393
393
|
|
|
@@ -128,7 +128,7 @@ y = orpo_loss(lm_head.weight, x, target)
|
|
|
128
128
|
- **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
|
|
129
129
|
- **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
|
|
130
130
|
- **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
|
|
131
|
-
- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
|
|
131
|
+
- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)
|
|
132
132
|
|
|
133
133
|
## Installation
|
|
134
134
|
|
|
@@ -338,8 +338,8 @@ loss.backward()
|
|
|
338
338
|
## Contact
|
|
339
339
|
|
|
340
340
|
- For issues, create a Github ticket in this repository
|
|
341
|
-
- For open discussion, join [our discord channel](https://discord.
|
|
342
|
-
- For formal collaboration, send an email to yannchen@linkedin.com
|
|
341
|
+
- For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
|
|
342
|
+
- For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
|
|
343
343
|
|
|
344
344
|
## Cite this work
|
|
345
345
|
|
|
@@ -8,8 +8,8 @@ import matplotlib.pyplot as plt
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
import seaborn as sns
|
|
10
10
|
|
|
11
|
-
DATA_PATH = "data/all_benchmark_data.csv"
|
|
12
|
-
VISUALIZATIONS_PATH = "visualizations/"
|
|
11
|
+
DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/all_benchmark_data.csv"))
|
|
12
|
+
VISUALIZATIONS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "visualizations/"))
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
@dataclass
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import triton
|
|
6
|
+
|
|
7
|
+
from utils import QUANTILES
|
|
8
|
+
from utils import SingleBenchmarkRunInput
|
|
9
|
+
from utils import SingleBenchmarkRunOutput
|
|
10
|
+
from utils import _test_memory
|
|
11
|
+
from utils import parse_benchmark_script_args
|
|
12
|
+
from utils import run_benchmarks
|
|
13
|
+
|
|
14
|
+
from liger_kernel.utils import infer_device
|
|
15
|
+
|
|
16
|
+
device = infer_device()
|
|
17
|
+
|
|
18
|
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time one DyT benchmark point (provider x mode x BT).

    Reads from ``input``:
      - ``x``: number of rows (``BT``) of the benchmark tensor.
      - ``kernel_provider``: one of ``"liger"``, ``"torch"``, ``"torch_compile"``.
      - ``kernel_operation_mode``: one of ``"forward"``, ``"backward"``, ``"full"``.
      - ``extra_benchmark_config``: dict with ``"hidden_size"`` and ``"dtype"``.

    Returns:
        A ``SingleBenchmarkRunOutput`` whose ``y_20`` / ``y_50`` / ``y_80`` hold the
        timings reported by ``triton.testing.do_bench``.

    Raises:
        ValueError: if the provider or the mode is not one of the values above.
            (Previously an unknown provider made ``fwd`` return ``None`` and an
            unknown mode ended in an ``UnboundLocalError`` at the return.)
    """
    # Imported lazily: the ``test`` package is only importable once the repository
    # root is on ``sys.path``.
    from test.transformers.test_dyt import LigerDyT
    from test.transformers.test_dyt import TorchDyT

    BT = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    hidden_size = extra_benchmark_config["hidden_size"]
    dtype = extra_benchmark_config["dtype"]

    if mode not in ("forward", "backward", "full"):
        raise ValueError(f"Unsupported kernel_operation_mode: {mode!r}")

    x_shape = (BT, hidden_size)
    modules = {
        "torch": TorchDyT(hidden_size=hidden_size).to(device),
        "torch_compile": torch.compile(TorchDyT(hidden_size=hidden_size).to(device)),
        "liger": LigerDyT(hidden_size=hidden_size).to(device),
    }
    if provider not in modules:
        raise ValueError(f"Unsupported kernel_provider: {provider!r}")
    dyt = modules[provider]

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        return dyt(x)

    # NOTE(review): the unpacking order below assumes QUANTILES is ordered
    # (0.5, 0.2, 0.8) -- confirm against ``utils.QUANTILES``.
    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    elif mode == "backward":
        y = fwd()
        # The same graph is replayed on every timing iteration, so it must be retained.
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    else:  # mode == "full"

        def full():
            y = fwd()
            y.backward(dy)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory of one DyT forward+backward pass for a (provider, BT) point.

    Reads from ``input``:
      - ``x``: number of rows (``BT``) of the benchmark tensor.
      - ``kernel_provider``: one of ``"liger"``, ``"torch"``, ``"torch_compile"``.
      - ``extra_benchmark_config``: dict with ``"hidden_size"`` and ``"dtype"``.

    Returns:
        A ``SingleBenchmarkRunOutput`` whose ``y_20`` / ``y_50`` / ``y_80`` hold the
        values reported by ``_test_memory``.

    Raises:
        ValueError: if the provider is not one of the values above. (Previously
            ``fwd`` silently returned ``None`` and the failure surfaced later as
            an ``AttributeError`` on ``None.backward``.)
    """
    # Imported lazily: the ``test`` package is only importable once the repository
    # root is on ``sys.path``.
    from test.transformers.test_dyt import LigerDyT
    from test.transformers.test_dyt import TorchDyT

    BT = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    hidden_size = extra_benchmark_config["hidden_size"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (BT, hidden_size)
    modules = {
        "torch": TorchDyT(hidden_size=hidden_size).to(device),
        "torch_compile": torch.compile(TorchDyT(hidden_size=hidden_size).to(device)),
        "liger": LigerDyT(hidden_size=hidden_size).to(device),
    }
    if provider not in modules:
        raise ValueError(f"Unsupported kernel_provider: {provider!r}")
    dyt = modules[provider]

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        return dyt(x)

    def full():
        y = fwd()
        y.backward(dy, retain_graph=True)

    # NOTE(review): the unpacking order assumes QUANTILES is ordered
    # (0.5, 0.2, 0.8) -- confirm against ``utils.QUANTILES``.
    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
if __name__ == "__main__":
    # Script entry point: run the speed sweep first, then the memory sweep, both
    # over the same grid of row counts, providers and extra configs.
    args = parse_benchmark_script_args()

    common_configs = dict(
        kernel_name="dyt",
        x_name="BT",
        x_label="batch_size * seq_len",
        x_values=[1 << exponent for exponent in range(10, 15)],  # 1024 .. 16384
        kernel_providers=["liger", "torch", "torch_compile"],
        extra_benchmark_configs=[{"hidden_size": 4096, "dtype": torch.float32}],
        overwrite=args.overwrite,
    )

    # (benchmark function, operation modes, metric name, metric unit)
    sweeps = (
        (bench_speed_dyt, ["forward", "backward", "full"], "speed", "ms"),
        (bench_memory_dyt, ["full"], "memory", "MB"),
    )
    for bench_fn, operation_modes, metric, unit in sweeps:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=operation_modes,
            metric_name=metric,
            metric_unit=unit,
            **common_configs,
        )
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel_nightly"
|
|
7
|
-
version = "0.5.5.dev20250320214749"
|
|
7
|
+
version = "0.5.5.dev20250324181221"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
|
|
7
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
8
|
+
from liger_kernel.ops.utils import compare_version
|
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
from liger_kernel.ops.utils import infer_device
|
|
11
|
+
|
|
12
|
+
if compare_version("triton", operator.ge, "3.0.0"):
|
|
13
|
+
try:
|
|
14
|
+
# typical import path with dispatch available
|
|
15
|
+
from triton.language.extra.libdevice import tanh
|
|
16
|
+
except ModuleNotFoundError:
|
|
17
|
+
# for working with NGC containers
|
|
18
|
+
from triton.language.extra.cuda.libdevice import tanh
|
|
19
|
+
else:
|
|
20
|
+
from triton.language.math import tanh
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@triton.jit
def _dyt_fwd_kernel(
    x_ptr,
    x_row_stride,
    alpha_ptr,
    gamma_ptr,
    beta_ptr,
    y_ptr,
    y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Dynamic Tanh forward: y = gamma * tanh(alpha * x) + beta.

    Each program instance handles one row of x and writes the matching row of y.

    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - beta: (C)
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    # Lanes at or beyond n_cols are neither loaded nor stored.
    in_bounds = cols < n_cols

    x_row_ptr = x_ptr + row * x_row_stride
    y_row_ptr = y_ptr + row * y_row_stride

    alpha = tl.load(alpha_ptr)
    gamma = tl.load(gamma_ptr + cols, mask=in_bounds)
    beta = tl.load(beta_ptr + cols, mask=in_bounds)
    x = tl.load(x_row_ptr + cols, mask=in_bounds)

    # tanh is evaluated on a float32 copy of alpha * x.
    scaled = (alpha * x).cast(tl.float32)
    y = gamma * tanh(scaled) + beta
    tl.store(y_row_ptr + cols, y, mask=in_bounds)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@triton.jit
def _dyt_bwd_kernel(
    x_ptr,
    x_row_stride,
    dy_ptr,
    dy_row_stride,
    dx_ptr,
    dx_row_stride,
    alpha_ptr,
    dalpha_ptr,
    gamma_ptr,
    dgamma_ptr,
    dgamma_row_stride,
    n_cols,
    n_rows,
    ROWS_PER_PROGRAM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Dynamic Tanh backward for y = gamma * tanh(alpha * x) + beta.

    Each program instance walks a contiguous slice of at most ROWS_PER_PROGRAM
    rows, writes dx for those rows, and writes its partial sums of dgamma and
    dalpha into slot ``pid`` of the corresponding buffers. The caller is expected
    to reduce those partial buffers over their first dimension.

    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - dx: (BT, C)
        - dy: (BT, C)
        - dgamma: (sm_count, C)
        - dalpha: (sm_count,)
    """
    # d(gamma * tanh(alpha * x) + beta) / dx
    # = gamma * (1 - tanh^2(alpha * x)) * alpha
    # d(gamma * tanh(alpha * x) + beta) / dalpha
    # = gamma * (1 - tanh^2(alpha * x)) * x
    # d(gamma * tanh(alpha * x) + beta) / dgamma
    # = tanh(alpha * x)
    # d(gamma * tanh(alpha * x) + beta) / dbeta = 1
    pid = tl.program_id(0)

    # Half-open row range [row_start, row_end) owned by this program; the range is
    # empty for programs whose slice starts at or beyond n_rows.
    row_start = pid * ROWS_PER_PROGRAM
    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    # Per-program accumulators, kept in float32.
    dalpha = 0.0
    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    x_ptr += row_start * x_row_stride
    dx_ptr += row_start * dx_row_stride
    dy_ptr += row_start * dy_row_stride
    alpha = tl.load(alpha_ptr)
    # other=0.0 so masked-out lanes contribute nothing to the sums below.
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)

    for _ in tl.range(row_start, row_end):
        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        tanh_ax = tanh((alpha * x).cast(tl.float32))
        sech2_ax = 1 - tanh_ax * tanh_ax

        dx = dy * gamma * sech2_ax * alpha
        dalpha += tl.sum(dy * gamma * sech2_ax * x)
        dgamma += dy * tanh_ax
        tl.store(dx_ptr + offsets, dx, mask=mask)

        # Advance all three row pointers to the next row.
        dy_ptr += dy_row_stride
        x_ptr += x_row_stride
        dx_ptr += dx_row_stride

    # Publish this program's partial sums.
    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
    tl.store(dalpha_ptr + pid, dalpha)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def liger_dyt_fwd(x, alpha, gamma, beta):
    """Launch the DyT forward kernel and return ``gamma * tanh(alpha * x) + beta``.

    ``x`` is viewed as a 2-D ``(rows, last_dim)`` tensor, one kernel program is
    launched per row, and the result is viewed back to ``x``'s original shape.
    Because ``view`` is used, ``x`` must be viewable as 2-D without a copy.
    """
    original_shape = x.shape
    x_2d = x.view(-1, original_shape[-1])
    n_rows, n_cols = x_2d.shape
    y_2d = torch.empty_like(x_2d)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    launch_grid = (n_rows,)  # one program per row
    _dyt_fwd_kernel[launch_grid](
        x_ptr=x_2d,
        x_row_stride=x_2d.stride(0),
        alpha_ptr=alpha,
        gamma_ptr=gamma,
        beta_ptr=beta,
        y_ptr=y_2d,
        y_row_stride=y_2d.stride(0),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return y_2d.view(*original_shape)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def liger_dyt_bwd(dy, x, alpha, gamma):
    """Launch the DyT backward kernel and reduce its partial results.

    Args:
        dy: upstream gradient, same shape as ``x``.
        x: the forward input.
        alpha: the scalar scale, shape ``(1,)``.
        gamma: the per-feature weight, shape ``(C,)``.

    Returns:
        ``(dx, dalpha, dgamma, dbeta)`` where ``dx`` has the shape of ``dy`` and
        the dtype of ``x``, ``dalpha`` has shape ``(1,)``, and ``dgamma`` /
        ``dbeta`` have shape ``(C,)``; the three parameter gradients are cast to
        ``x``'s dtype.

    Raises:
        RuntimeError: if the feature dimension exceeds the kernel block size.
    """
    shape = dy.shape
    dtype = x.dtype
    dim = shape[-1]
    dy = dy.view(-1, dim)
    x = x.view(-1, dim)
    n_rows, n_cols = dy.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    # The kernel handles a whole row in a single block, so the row must fit.
    # NOTE(review): calculate_settings presumably already guarantees this -- confirm.
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )

    # One program per multiprocessor (CUDA) / subslice (XPU); a single program
    # on any other device type.
    sm_count = 1
    device = infer_device()
    if device == "cuda":
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    elif device == "xpu":
        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count

    # Allocate dx in the input's dtype. It used to be float32 unconditionally,
    # which mismatched x for half-precision inputs (and doubled the buffer) and
    # truncated precision for float64 inputs. The kernel still does its math in
    # float32 or wider; the cast happens on store.
    dx = torch.empty_like(x)
    # Per-program partial sums: every launched program writes its own slot, so
    # uninitialised memory is never read back.
    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)

    grid = (sm_count,)
    rows_per_program = triton.cdiv(n_rows, sm_count)
    _dyt_bwd_kernel[grid](
        x_ptr=x,
        x_row_stride=x.stride(0),
        dy_ptr=dy,
        dy_row_stride=dy.stride(0),
        dx_ptr=dx,
        dx_row_stride=dx.stride(0),
        alpha_ptr=alpha,
        dalpha_ptr=_dalpha,
        gamma_ptr=gamma,
        dgamma_ptr=_dgamma,
        dgamma_row_stride=_dgamma.stride(0),
        n_cols=n_cols,
        n_rows=n_rows,
        ROWS_PER_PROGRAM=rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    # Reduce the per-program partials in float32 before casting down.
    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
    dgamma = _dgamma.sum(dim=0).to(dtype)
    # beta enters the forward additively, so its gradient is the column sum of dy.
    dbeta = dy.sum(dim=0).to(dtype)
    return dx.view(*shape), dalpha, dgamma, dbeta
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class LigerDyTFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton Dynamic Tanh forward/backward launchers."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x, alpha, gamma, beta):
        """Compute ``gamma * tanh(alpha * x) + beta`` and stash the tensors backward needs."""
        output = liger_dyt_fwd(x, alpha, gamma, beta)
        # beta is not saved: its gradient depends only on the upstream gradient.
        ctx.save_for_backward(x, alpha, gamma)
        return output

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """Return gradients for ``(x, alpha, gamma, beta)``, in forward-argument order."""
        saved_x, saved_alpha, saved_gamma = ctx.saved_tensors
        grad_x, grad_alpha, grad_gamma, grad_beta = liger_dyt_bwd(
            grad_output,
            saved_x,
            saved_alpha,
            saved_gamma,
        )
        return (grad_x, grad_alpha, grad_gamma, grad_beta)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
|
|
2
2
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
|
|
3
|
+
from liger_kernel.transformers.dyt import LigerDyT # noqa: F401
|
|
3
4
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
|
|
4
5
|
from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
|
|
5
6
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from liger_kernel.ops.dyt import LigerDyTFunction
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LigerDyT(nn.Module):
    """Dynamic Tanh layer backed by ``LigerDyTFunction``.

    Holds a scalar ``alpha`` (initialised to ``init_alpha``), a per-feature
    ``gamma`` (initialised to ones) and a per-feature ``beta`` (initialised to
    zeros), each of length ``hidden_size`` except ``alpha`` which has length 1.
    """

    def __init__(self, hidden_size, init_alpha=0.5):
        super().__init__()
        self.hidden_size = hidden_size
        self.init_alpha = init_alpha

        alpha_init = torch.ones(1) * init_alpha
        gamma_init = torch.ones(hidden_size)
        beta_init = torch.zeros(hidden_size)

        # Registration order (alpha, gamma, beta) is kept as-is.
        self.alpha = nn.Parameter(alpha_init)
        self.gamma = nn.Parameter(gamma_init)
        self.beta = nn.Parameter(beta_init)

    def forward(self, x):
        """Apply the layer to ``x`` through the custom autograd function."""
        return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)

    def extra_repr(self):
        """Summarise the constructor arguments for ``repr(module)``."""
        return "{}, init_alpha={}".format(self.hidden_size, self.init_alpha)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
4
|
+
from liger_kernel.ops.dyt import LigerDyTFunction
|
|
4
5
|
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
|
|
5
6
|
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
|
|
6
7
|
from liger_kernel.ops.geglu import LigerGELUMulFunction
|
|
@@ -192,3 +193,7 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
|
192
193
|
|
|
193
194
|
def liger_swiglu(a, b):
|
|
194
195
|
return LigerSiLUMulFunction.apply(a, b)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def liger_dyt(x, alpha, gamma, beta):
    """Functional form of Dynamic Tanh; forwards all arguments to ``LigerDyTFunction.apply``."""
    dyt_output = LigerDyTFunction.apply(x, alpha, gamma, beta)
    return dyt_output
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: liger_kernel_nightly
|
|
3
|
-
Version: 0.5.5.dev20250320214749
|
|
3
|
+
Version: 0.5.5.dev20250324181221
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -176,7 +176,7 @@ y = orpo_loss(lm_head.weight, x, target)
|
|
|
176
176
|
- **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
|
|
177
177
|
- **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
|
|
178
178
|
- **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
|
|
179
|
-
- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
|
|
179
|
+
- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)
|
|
180
180
|
|
|
181
181
|
## Installation
|
|
182
182
|
|
|
@@ -386,8 +386,8 @@ loss.backward()
|
|
|
386
386
|
## Contact
|
|
387
387
|
|
|
388
388
|
- For issues, create a Github ticket in this repository
|
|
389
|
-
- For open discussion, join [our discord channel](https://discord.
|
|
390
|
-
- For formal collaboration, send an email to yannchen@linkedin.com
|
|
389
|
+
- For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
|
|
390
|
+
- For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
|
|
391
391
|
|
|
392
392
|
## Cite this work
|
|
393
393
|
|
|
@@ -24,6 +24,7 @@ benchmark/scripts/benchmark_cpo_loss.py
|
|
|
24
24
|
benchmark/scripts/benchmark_cross_entropy.py
|
|
25
25
|
benchmark/scripts/benchmark_distill_jsd_loss.py
|
|
26
26
|
benchmark/scripts/benchmark_dpo_loss.py
|
|
27
|
+
benchmark/scripts/benchmark_dyt.py
|
|
27
28
|
benchmark/scripts/benchmark_embedding.py
|
|
28
29
|
benchmark/scripts/benchmark_fused_linear_cross_entropy.py
|
|
29
30
|
benchmark/scripts/benchmark_fused_linear_jsd.py
|
|
@@ -121,6 +122,7 @@ src/liger_kernel/chunked_loss/orpo_loss.py
|
|
|
121
122
|
src/liger_kernel/chunked_loss/simpo_loss.py
|
|
122
123
|
src/liger_kernel/ops/__init__.py
|
|
123
124
|
src/liger_kernel/ops/cross_entropy.py
|
|
125
|
+
src/liger_kernel/ops/dyt.py
|
|
124
126
|
src/liger_kernel/ops/fused_linear_cross_entropy.py
|
|
125
127
|
src/liger_kernel/ops/fused_linear_jsd.py
|
|
126
128
|
src/liger_kernel/ops/geglu.py
|
|
@@ -139,6 +141,7 @@ src/liger_kernel/ops/experimental/mm_int8int2.py
|
|
|
139
141
|
src/liger_kernel/transformers/__init__.py
|
|
140
142
|
src/liger_kernel/transformers/auto_model.py
|
|
141
143
|
src/liger_kernel/transformers/cross_entropy.py
|
|
144
|
+
src/liger_kernel/transformers/dyt.py
|
|
142
145
|
src/liger_kernel/transformers/functional.py
|
|
143
146
|
src/liger_kernel/transformers/fused_linear_cross_entropy.py
|
|
144
147
|
src/liger_kernel/transformers/fused_linear_jsd.py
|
|
@@ -209,6 +212,7 @@ test/resources/tiny_shakespeare_tokenized/dataset_info.json
|
|
|
209
212
|
test/resources/tiny_shakespeare_tokenized/state.json
|
|
210
213
|
test/transformers/test_auto_model.py
|
|
211
214
|
test/transformers/test_cross_entropy.py
|
|
215
|
+
test/transformers/test_dyt.py
|
|
212
216
|
test/transformers/test_embedding.py
|
|
213
217
|
test/transformers/test_flex_attention.py
|
|
214
218
|
test/transformers/test_fused_linear_cross_entropy.py
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
|
|
5
|
+
from test.utils import assert_verbose_allclose
|
|
6
|
+
from test.utils import infer_device
|
|
7
|
+
from test.utils import set_seed
|
|
8
|
+
from test.utils import supports_bfloat16
|
|
9
|
+
|
|
10
|
+
from liger_kernel.ops.dyt import LigerDyTFunction
|
|
11
|
+
from liger_kernel.transformers.dyt import LigerDyT
|
|
12
|
+
from liger_kernel.transformers.functional import liger_dyt
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TorchDyT(nn.Module):
    """Plain-PyTorch Dynamic Tanh: ``gamma * tanh(alpha * x) + beta``.

    ``alpha`` is a length-1 parameter initialised to ``init_alpha``; ``gamma``
    and ``beta`` have length ``hidden_size`` and start at ones and zeros.
    """

    def __init__(self, hidden_size, init_alpha=0.5):
        super().__init__()
        alpha_init = torch.ones(1) * init_alpha
        self.alpha = nn.Parameter(alpha_init)
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.gamma * squashed + self.beta
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
set_seed(42)
|
|
27
|
+
device = infer_device()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@pytest.mark.parametrize("init_alpha", [0.5, 0.2, 1.0])
@pytest.mark.parametrize(
    "B, T, hidden_size",
    [
        (2, 8, 4096),
        (4, 16, 2048),
        (1, 1, 1023),  # Minimal batch/seq with near power-of-2 hidden
        (3, 7, 256),  # Prime numbers for batch/seq
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
    ],
)
def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol):
    """Check LigerDyT against the PyTorch reference for outputs and all gradients."""
    base_input = torch.randn(B, T, hidden_size, device=device, dtype=dtype)

    x_ref = base_input.clone().requires_grad_(True)
    x_liger = base_input.clone().requires_grad_(True)

    # Random weights shared by both modules (drawn in alpha, gamma, beta order).
    shared_weights = {
        "alpha": torch.randn(1, device=device, dtype=dtype),
        "gamma": torch.randn(hidden_size, device=device, dtype=dtype),
        "beta": torch.randn(hidden_size, device=device, dtype=dtype),
    }

    ref_module = TorchDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
    liger_module = LigerDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
    for module in (ref_module, liger_module):
        for param_name, value in shared_weights.items():
            getattr(module, param_name).data = value.clone()

    ref_output = ref_module(x_ref)
    liger_output = liger_module(x_liger)

    assert_verbose_allclose(ref_output, liger_output, rtol=rtol, atol=atol)

    upstream_grad = torch.randn_like(base_input)
    ref_output.backward(upstream_grad)
    liger_output.backward(upstream_grad)

    assert_verbose_allclose(x_ref.grad, x_liger.grad, rtol=rtol, atol=atol)
    for param_name in ("alpha", "gamma", "beta"):
        assert_verbose_allclose(
            getattr(ref_module, param_name).grad,
            getattr(liger_module, param_name).grad,
            rtol=rtol,
            atol=atol,
        )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.mark.parametrize(
    "B, T, hidden_size",
    [
        (2, 8, 4096),
        (4, 16, 2048),
        (1, 1, 1023),  # Minimal batch/seq with near power-of-2 hidden
        (3, 7, 256),  # Prime numbers for batch/seq
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        # atol is for small values: they have more difference, so set atol higher
        # rtol is for larger values: they are very close, so set rtol lower
        (torch.float32, 1e-5, 1e-5),
        pytest.param(
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
def test_liger_dyt_functional(B, T, hidden_size, dtype, atol, rtol):
    """Check that the ``liger_dyt`` wrapper matches ``LigerDyTFunction.apply`` directly."""
    base_input = torch.randn(B, T, hidden_size, device=device, dtype=dtype)

    x_wrapper = base_input.clone().requires_grad_(True)
    x_direct = base_input.clone().requires_grad_(True)

    # initialize weights
    alpha = torch.randn(1, device=device, dtype=dtype)
    gamma = torch.randn(hidden_size, device=device, dtype=dtype)
    beta = torch.randn(hidden_size, device=device, dtype=dtype)

    def trainable_copies():
        # Independent leaf tensors so each code path accumulates its own gradients.
        return tuple(weight.clone().requires_grad_(True) for weight in (alpha, gamma, beta))

    alpha1, gamma1, beta1 = trainable_copies()
    alpha2, gamma2, beta2 = trainable_copies()

    out_wrapper = liger_dyt(x_wrapper, alpha=alpha1, gamma=gamma1, beta=beta1)
    out_direct = LigerDyTFunction.apply(x_direct, alpha2, gamma2, beta2)

    assert_verbose_allclose(out_wrapper, out_direct, rtol=rtol, atol=atol)

    upstream_grad = torch.randn_like(base_input)
    out_wrapper.backward(upstream_grad)
    out_direct.backward(upstream_grad)

    grad_pairs = (
        (x_wrapper.grad, x_direct.grad),
        (alpha1.grad, alpha2.grad),
        (gamma1.grad, gamma2.grad),
        (beta1.grad, beta2.grad),
    )
    for grad_from_wrapper, grad_from_direct in grad_pairs:
        assert_verbose_allclose(grad_from_wrapper, grad_from_direct, rtol=rtol, atol=atol)
|
|
File without changes
|