liger-kernel-nightly 0.5.3.dev20250221230243__tar.gz → 0.5.3.dev20250224175624__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_rope.py +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/layer_norm.py +20 -7
- liger_kernel_nightly-0.5.3.dev20250224175624/src/liger_kernel/utils.py +62 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_layer_norm.py +20 -5
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_rope.py +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/utils.py +0 -47
- liger_kernel_nightly-0.5.3.dev20250221230243/src/liger_kernel/utils.py +0 -15
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/Makefile +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/setup.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/triton/test_triton_monkey_patch.py +0 -0
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
|
|
4
|
-
from test.utils import transformers_version_dispatch
|
|
5
4
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
|
6
5
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
|
7
6
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
@@ -14,6 +13,7 @@ from utils import run_benchmarks
|
|
|
14
13
|
|
|
15
14
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
16
15
|
from liger_kernel.utils import infer_device
|
|
16
|
+
from liger_kernel.utils import transformers_version_dispatch
|
|
17
17
|
|
|
18
18
|
device = infer_device()
|
|
19
19
|
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel_nightly"
|
|
7
|
-
version = "0.5.3.dev20250221230243"
|
|
7
|
+
version = "0.5.3.dev20250224175624"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -57,13 +57,14 @@ def _layer_norm_forward_kernel(
|
|
|
57
57
|
B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
|
|
58
58
|
|
|
59
59
|
mean = tl.sum(X_row, axis=0) / n_cols
|
|
60
|
-
|
|
60
|
+
Xmm = tl.where(mask, X_row - mean, 0)
|
|
61
|
+
var = tl.sum(Xmm * Xmm, axis=0) / n_cols
|
|
61
62
|
rstd = rsqrt(var + eps)
|
|
62
63
|
|
|
63
64
|
tl.store(Mean_ptr, mean)
|
|
64
65
|
tl.store(RSTD_ptr, rstd)
|
|
65
66
|
|
|
66
|
-
Y_row =
|
|
67
|
+
Y_row = Xmm * rstd * W_row + B_row
|
|
67
68
|
|
|
68
69
|
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
|
69
70
|
|
|
@@ -147,9 +148,11 @@ def layer_norm_forward(X, W, B, eps):
|
|
|
147
148
|
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
148
149
|
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
149
150
|
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
151
|
+
if X.shape[1] != W.shape[0]:
|
|
152
|
+
raise ValueError(
|
|
153
|
+
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
|
|
154
|
+
f"must match weight size (W.shape[0]={W.shape[0]})"
|
|
155
|
+
)
|
|
153
156
|
|
|
154
157
|
_layer_norm_forward_kernel[(n_rows,)](
|
|
155
158
|
Y,
|
|
@@ -190,11 +193,21 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
190
193
|
|
|
191
194
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
192
195
|
if n_cols > BLOCK_SIZE:
|
|
193
|
-
raise RuntimeError(
|
|
196
|
+
raise RuntimeError(
|
|
197
|
+
f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
|
|
198
|
+
)
|
|
194
199
|
|
|
195
200
|
rows_per_program = math.ceil(n_rows / sm_count)
|
|
196
201
|
grid = (sm_count,)
|
|
197
|
-
triton_dtype =
|
|
202
|
+
triton_dtype = (
|
|
203
|
+
tl.float32
|
|
204
|
+
if X.dtype == torch.float32
|
|
205
|
+
else tl.bfloat16
|
|
206
|
+
if X.dtype == torch.bfloat16
|
|
207
|
+
else tl.float16
|
|
208
|
+
if X.dtype == torch.float16
|
|
209
|
+
else tl.float32 # fallback to float32 for other types
|
|
210
|
+
)
|
|
198
211
|
_layer_norm_backward_kernel[grid](
|
|
199
212
|
X,
|
|
200
213
|
W,
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def infer_device():
|
|
5
|
+
"""
|
|
6
|
+
Get current device name based on available devices
|
|
7
|
+
"""
|
|
8
|
+
if torch.cuda.is_available():
|
|
9
|
+
return "cuda"
|
|
10
|
+
elif torch.xpu.is_available():
|
|
11
|
+
return "xpu"
|
|
12
|
+
elif torch.hip.is_available():
|
|
13
|
+
return "hip"
|
|
14
|
+
else:
|
|
15
|
+
return "cpu"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def transformers_version_dispatch(
|
|
19
|
+
required_version: str,
|
|
20
|
+
before_fn,
|
|
21
|
+
after_fn,
|
|
22
|
+
before_args: tuple = (),
|
|
23
|
+
after_args: tuple = (),
|
|
24
|
+
before_kwargs: dict = None,
|
|
25
|
+
after_kwargs: dict = None,
|
|
26
|
+
):
|
|
27
|
+
"""
|
|
28
|
+
Dispatches to different functions based on package version comparison.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
required_version: Version to compare against (e.g. "4.48.0")
|
|
32
|
+
before_fn: Function to call if package_version < required_version
|
|
33
|
+
after_fn: Function to call if package_version >= required_version
|
|
34
|
+
before_args: Positional arguments for before_fn
|
|
35
|
+
after_args: Positional arguments for after_fn
|
|
36
|
+
before_kwargs: Keyword arguments for before_fn
|
|
37
|
+
after_kwargs: Keyword arguments for after_fn
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
Result from either before_fn or after_fn
|
|
41
|
+
|
|
42
|
+
Example:
|
|
43
|
+
>>> rotary_emb = transformers_version_dispatch(
|
|
44
|
+
... "4.48.0",
|
|
45
|
+
... LlamaRotaryEmbedding,
|
|
46
|
+
... LlamaRotaryEmbedding,
|
|
47
|
+
... before_args=(head_dim,),
|
|
48
|
+
... after_args=(LlamaConfig(head_dim=head_dim),),
|
|
49
|
+
... before_kwargs={'device': device},
|
|
50
|
+
... after_kwargs={'device': device}
|
|
51
|
+
... )
|
|
52
|
+
"""
|
|
53
|
+
from packaging import version
|
|
54
|
+
from transformers import __version__ as transformers_version
|
|
55
|
+
|
|
56
|
+
before_kwargs = before_kwargs or {}
|
|
57
|
+
after_kwargs = after_kwargs or {}
|
|
58
|
+
|
|
59
|
+
if version.parse(transformers_version) < version.parse(required_version):
|
|
60
|
+
return before_fn(*before_args, **before_kwargs)
|
|
61
|
+
else:
|
|
62
|
+
return after_fn(*after_args, **after_kwargs)
|
|
@@ -14,6 +14,8 @@ device = infer_device()
|
|
|
14
14
|
[
|
|
15
15
|
(2, 8, 64),
|
|
16
16
|
(4, 16, 128),
|
|
17
|
+
(1, 1, 1023), # Minimal batch/seq with near power-of-2 hidden
|
|
18
|
+
(3, 7, 256), # Prime numbers for batch/seq
|
|
17
19
|
],
|
|
18
20
|
)
|
|
19
21
|
@pytest.mark.parametrize(
|
|
@@ -22,7 +24,15 @@ device = infer_device()
|
|
|
22
24
|
(torch.float32, 1e-5, 1e-5),
|
|
23
25
|
],
|
|
24
26
|
)
|
|
25
|
-
def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol):
|
|
27
|
+
def test_liger_layer_norm(
|
|
28
|
+
batch_size: int,
|
|
29
|
+
seq_len: int,
|
|
30
|
+
hidden_size: int,
|
|
31
|
+
dtype: torch.dtype,
|
|
32
|
+
atol: float,
|
|
33
|
+
rtol: float,
|
|
34
|
+
) -> None:
|
|
35
|
+
"""Test basic layer norm functionality against PyTorch implementation."""
|
|
26
36
|
torch.manual_seed(0)
|
|
27
37
|
|
|
28
38
|
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
|
|
@@ -64,7 +74,15 @@ def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol):
|
|
|
64
74
|
(torch.float32, 1e-5, 1e-5),
|
|
65
75
|
],
|
|
66
76
|
)
|
|
67
|
-
def test_liger_layer_norm_functional(hidden_size, batch_size, seq_len, dtype, atol, rtol):
|
|
77
|
+
def test_liger_layer_norm_functional(
|
|
78
|
+
hidden_size: int,
|
|
79
|
+
batch_size: int,
|
|
80
|
+
seq_len: int,
|
|
81
|
+
dtype: torch.dtype,
|
|
82
|
+
atol: float,
|
|
83
|
+
rtol: float,
|
|
84
|
+
) -> None:
|
|
85
|
+
"""Test functional layer norm interface against autograd function."""
|
|
68
86
|
torch.manual_seed(0)
|
|
69
87
|
|
|
70
88
|
input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
|
|
@@ -73,12 +91,10 @@ def test_liger_layer_norm_functional(hidden_size, batch_size, seq_len, dtype, at
|
|
|
73
91
|
x2 = input.clone().requires_grad_(True)
|
|
74
92
|
|
|
75
93
|
w = torch.randn(hidden_size, device=device, dtype=dtype)
|
|
76
|
-
|
|
77
94
|
w1 = w.clone().requires_grad_(True)
|
|
78
95
|
w2 = w.clone().requires_grad_(True)
|
|
79
96
|
|
|
80
97
|
b = torch.randn(hidden_size, device=device, dtype=dtype)
|
|
81
|
-
|
|
82
98
|
b1 = b.clone().requires_grad_(True)
|
|
83
99
|
b2 = b.clone().requires_grad_(True)
|
|
84
100
|
|
|
@@ -88,7 +104,6 @@ def test_liger_layer_norm_functional(hidden_size, batch_size, seq_len, dtype, at
|
|
|
88
104
|
assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
|
|
89
105
|
|
|
90
106
|
grad_output = torch.randn_like(y2)
|
|
91
|
-
|
|
92
107
|
y1.backward(grad_output, retain_graph=True)
|
|
93
108
|
y2.backward(grad_output, retain_graph=True)
|
|
94
109
|
|
|
@@ -2,7 +2,6 @@ import pytest
|
|
|
2
2
|
import torch
|
|
3
3
|
|
|
4
4
|
from test.utils import supports_bfloat16
|
|
5
|
-
from test.utils import transformers_version_dispatch
|
|
6
5
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
|
7
6
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
|
8
7
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
@@ -11,6 +10,7 @@ from liger_kernel.ops.rope import LigerRopeFunction
|
|
|
11
10
|
from liger_kernel.transformers.functional import liger_rope
|
|
12
11
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
13
12
|
from liger_kernel.utils import infer_device
|
|
13
|
+
from liger_kernel.utils import transformers_version_dispatch
|
|
14
14
|
|
|
15
15
|
device = infer_device()
|
|
16
16
|
|
|
@@ -214,53 +214,6 @@ def supports_bfloat16():
|
|
|
214
214
|
return False
|
|
215
215
|
|
|
216
216
|
|
|
217
|
-
def transformers_version_dispatch(
|
|
218
|
-
required_version: str,
|
|
219
|
-
before_fn,
|
|
220
|
-
after_fn,
|
|
221
|
-
before_args: tuple = (),
|
|
222
|
-
after_args: tuple = (),
|
|
223
|
-
before_kwargs: dict = None,
|
|
224
|
-
after_kwargs: dict = None,
|
|
225
|
-
):
|
|
226
|
-
"""
|
|
227
|
-
Dispatches to different functions based on package version comparison.
|
|
228
|
-
|
|
229
|
-
Args:
|
|
230
|
-
required_version: Version to compare against (e.g. "4.48.0")
|
|
231
|
-
before_fn: Function to call if package_version < required_version
|
|
232
|
-
after_fn: Function to call if package_version >= required_version
|
|
233
|
-
before_args: Positional arguments for before_fn
|
|
234
|
-
after_args: Positional arguments for after_fn
|
|
235
|
-
before_kwargs: Keyword arguments for before_fn
|
|
236
|
-
after_kwargs: Keyword arguments for after_fn
|
|
237
|
-
|
|
238
|
-
Returns:
|
|
239
|
-
Result from either before_fn or after_fn
|
|
240
|
-
|
|
241
|
-
Example:
|
|
242
|
-
>>> rotary_emb = transformers_version_dispatch(
|
|
243
|
-
... "4.48.0",
|
|
244
|
-
... LlamaRotaryEmbedding,
|
|
245
|
-
... LlamaRotaryEmbedding,
|
|
246
|
-
... before_args=(head_dim,),
|
|
247
|
-
... after_args=(LlamaConfig(head_dim=head_dim),),
|
|
248
|
-
... before_kwargs={'device': device},
|
|
249
|
-
... after_kwargs={'device': device}
|
|
250
|
-
... )
|
|
251
|
-
"""
|
|
252
|
-
from packaging import version
|
|
253
|
-
from transformers import __version__ as transformers_version
|
|
254
|
-
|
|
255
|
-
before_kwargs = before_kwargs or {}
|
|
256
|
-
after_kwargs = after_kwargs or {}
|
|
257
|
-
|
|
258
|
-
if version.parse(transformers_version) < version.parse(required_version):
|
|
259
|
-
return before_fn(*before_args, **before_kwargs)
|
|
260
|
-
else:
|
|
261
|
-
return after_fn(*after_args, **after_kwargs)
|
|
262
|
-
|
|
263
|
-
|
|
264
217
|
def revert_liger_kernel_to_granite(model_config: MiniModelConfig):
|
|
265
218
|
"""
|
|
266
219
|
Revert all Liger kernel patches applied to Granite.
|
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
def infer_device():
|
|
5
|
-
"""
|
|
6
|
-
Get current device name based on available devices
|
|
7
|
-
"""
|
|
8
|
-
if torch.cuda.is_available():
|
|
9
|
-
return "cuda"
|
|
10
|
-
elif torch.xpu.is_available():
|
|
11
|
-
return "xpu"
|
|
12
|
-
elif torch.hip.is_available():
|
|
13
|
-
return "hip"
|
|
14
|
-
else:
|
|
15
|
-
return "cpu"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/NOTICE
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|