liger-kernel-nightly 0.5.3.dev20250221011147__tar.gz → 0.5.3.dev20250221162633__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/PKG-INFO +2 -1
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/README.md +1 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel_nightly.egg-info/PKG-INFO +2 -1
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
- liger_kernel_nightly-0.5.3.dev20250221162633/test/transformers/test_flex_attention.py +283 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/Makefile +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/setup.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: liger_kernel_nightly
|
3
|
-
Version: 0.5.3.dev20250221011147
|
3
|
+
Version: 0.5.3.dev20250221162633
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
@@ -313,6 +313,7 @@ loss.backward()
|
|
313
313
|
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
314
314
|
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
315
315
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
316
|
+
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
|
316
317
|
|
317
318
|
|
318
319
|
## Low-level APIs
|
@@ -265,6 +265,7 @@ loss.backward()
|
|
265
265
|
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
266
266
|
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
267
267
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
268
|
+
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
|
268
269
|
|
269
270
|
|
270
271
|
## Low-level APIs
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.5.3.dev20250221011147"
|
7
|
+
version = "0.5.3.dev20250221162633"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: liger_kernel_nightly
|
3
|
-
Version: 0.5.3.dev20250221011147
|
3
|
+
Version: 0.5.3.dev20250221162633
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
@@ -313,6 +313,7 @@ loss.backward()
|
|
313
313
|
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
314
314
|
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
315
315
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
316
|
+
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
|
316
317
|
|
317
318
|
|
318
319
|
## Low-level APIs
|
@@ -204,6 +204,7 @@ test/resources/tiny_shakespeare_tokenized/state.json
|
|
204
204
|
test/transformers/test_auto_model.py
|
205
205
|
test/transformers/test_cross_entropy.py
|
206
206
|
test/transformers/test_embedding.py
|
207
|
+
test/transformers/test_flex_attention.py
|
207
208
|
test/transformers/test_fused_linear_cross_entropy.py
|
208
209
|
test/transformers/test_fused_linear_jsd.py
|
209
210
|
test/transformers/test_geglu.py
|
@@ -0,0 +1,283 @@
|
|
1
|
+
import pytest
|
2
|
+
import torch
|
3
|
+
import torch.nn.functional as F
|
4
|
+
|
5
|
+
from test.utils import assert_verbose_allclose
|
6
|
+
from test.utils import set_seed
|
7
|
+
from test.utils import supports_bfloat16
|
8
|
+
from torch.nn.attention.flex_attention import create_block_mask
|
9
|
+
from torch.nn.attention.flex_attention import create_mask
|
10
|
+
from torch.nn.attention.flex_attention import flex_attention
|
11
|
+
|
12
|
+
from liger_kernel.utils import infer_device
|
13
|
+
|
14
|
+
|
15
|
+
def causal_mask(b, h, q_idx, kv_idx):
    """
    Mask predicate that keeps a (query, key) pair only when the key index does not exceed the query index.

    Parameters:
        b: Batch index (unused)
        h: Head index (unused)
        q_idx: Query position
        kv_idx: Key/value position

    Returns:
        The result of `kv_idx <= q_idx` (truthy where attention is allowed).
    """
    return kv_idx <= q_idx
|
17
|
+
|
18
|
+
|
19
|
+
def prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index):
    """
    Causal mask with one extra blocked region per batch element.

    A (query, key) pair is kept when `q_idx >= kv_idx`, unless the query sits at or after
    `rejected_index[b]` while the key lies in `[chosen_index[b], rejected_index[b])`.

    Parameters:
        b: Batch index, used to look up the per-batch boundaries
        h: Head index (unused)
        q_idx: Query position
        kv_idx: Key/value position
        rejected_index: Indexable by `b`; lower bound of the blocked query range and
            exclusive upper bound of the blocked key range
        chosen_index: Indexable by `b`; inclusive lower bound of the blocked key range

    Returns:
        Truthy where attention is allowed.
    """
    rejected_boundary = rejected_index[b]
    chosen_boundary = chosen_index[b]

    query_past_rejected = q_idx >= rejected_boundary
    key_in_blocked_span = (chosen_boundary <= kv_idx) & (kv_idx < rejected_boundary)
    within_causal = q_idx >= kv_idx

    return ~(query_past_rejected & key_in_blocked_span) & within_causal
|
23
|
+
|
24
|
+
|
25
|
+
device = infer_device()
|
26
|
+
set_seed(42)
|
27
|
+
|
28
|
+
|
29
|
+
def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device="cuda"):
    """
    Check `flex_attention` against `F.scaled_dot_product_attention` under the same mask.

    The forward outputs and the gradients with respect to query, key and value must all
    agree within the given tolerances.

    Parameters:
        B (int): Batch size
        H (int): Number of attention heads
        S (int): Sequence length
        D (int): Hidden dimension per head
        mask_func: Predicate `(b, h, q_idx, kv_idx)` used to build both mask variants
        dtype: Data type of the generated inputs
        atol (float): Absolute tolerance for comparison
        rtol (float): Relative tolerance for comparison
        device: Device on which inputs and masks are created
    """
    torch.manual_seed(0)

    # Attention-head inputs (post q/k/v projection), drawn in query, key, value order.
    shape = (B, H, S, D)
    reference_inputs = [torch.randn(*shape, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
    # Independent leaf copies so each implementation accumulates its own gradients.
    flex_inputs = [tensor.clone().detach().requires_grad_(True) for tensor in reference_inputs]

    block_mask = create_block_mask(mask_func, B, H, S, S, device=device)  # block mask consumed by flex_attention
    dense_mask = create_mask(mask_func, B, H, S, S, device=device)  # regular mask consumed by SDPA

    # With a purely causal mask, SDPA could be called with `is_causal=True` instead of an explicit
    # mask, e.g. F.scaled_dot_product_attention(query, key, value, is_causal=True).
    q_ref, k_ref, v_ref = reference_inputs
    torch_out = F.scaled_dot_product_attention(q_ref, k_ref, v_ref, attn_mask=dense_mask)

    q_flex, k_flex, v_flex = flex_inputs
    flex_out = flex_attention(q_flex, k_flex, v_flex, block_mask=block_mask)

    # Forward pass
    assert_verbose_allclose(flex_out, torch_out, atol=atol, rtol=rtol)

    # Backward pass: feed the same upstream gradient to both implementations.
    upstream_grad = torch.randn_like(torch_out)
    torch_out.backward(upstream_grad)
    flex_out.backward(upstream_grad)

    for flex_tensor, reference_tensor in zip(flex_inputs, reference_inputs):
        assert_verbose_allclose(flex_tensor.grad, reference_tensor.grad, atol=atol, rtol=rtol)
|
76
|
+
|
77
|
+
|
78
|
+
@pytest.mark.parametrize(
    "B, H, S, D",
    [
        (2, 8, 1024, 32),
        (3, 12, 2048, 64),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        pytest.param(
            torch.bfloat16,
            3e-2,
            5e-1,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
        (torch.float16, 1e-2, 5e-3),
        (torch.float32, 1e-3, 5e-4),
    ],
)
def test_correctness_flex(B, H, S, D, dtype, atol, rtol):
    """
    Run the flex-vs-SDPA comparison with a plain causal mask and with a prefix-sharing mask.

    Everything is created on the module-level `device` (from `infer_device()`) rather than a
    hard-coded "cuda", so the index tensors and the attention inputs share one device.
    """
    _test_correctness_flex(B, H, S, D, causal_mask, dtype, atol, rtol, device=device)

    # Roughly generate custom rejected and chosen indices for each batch:
    # chosen boundaries fall in the first half of the sequence, rejected ones in the second half.
    chosen_index = torch.randint(0, S // 2, (B,), device=device)
    rejected_index = torch.randint(S // 2, S, (B,), device=device)

    def wrapped_prefix_mask(b, h, q_idx, kv_idx):
        # Bind the per-batch boundaries so the predicate has the 4-argument mask signature.
        return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)

    _test_correctness_flex(B, H, S, D, wrapped_prefix_mask, dtype, atol, rtol, device=device)
|
109
|
+
|
110
|
+
|
111
|
+
def _test_correctness_prefix(
    B=2,
    H=8,
    P=512,
    C=256,
    R=256,
    D=32,
    dtype=torch.float32,
    atol=1e-3,
    rtol=5e-4,
    device="cuda",
):
    """
    Test that prefix-sharing attention matches separate computations (i.e. two separate causal
    masked attentions, prefix+chosen and prefix+rejected).

    The mental model is:

    A. prefix + chosen
        P
        P P
        P P P
        P P P C
        P P P C C
        P P P C C C

    B. prefix + rejected
        P
        P P
        P P P
        P P P R
        P P P R R
        P P P R R R

    C. shared prefix + chosen + rejected (rejected queries do not attend to chosen keys)
        P
        P P
        P P P
        P P P C
        P P P C C
        P P P C C C
        P P P       R
        P P P       R R
        P P P       R R R

    We compare attention outputs row-wise (along the sequence axis):
        1. prefix rows of shared attn (C.) == prefix rows of chosen attn (A.)
        2. prefix rows of shared attn (C.) == prefix rows of rejected attn (B.)
        3. chosen rows of shared attn (C.) == chosen rows of chosen attn (A.)
        4. rejected rows of shared attn (C.) == rejected rows of rejected attn (B.)

    Args:
        B: batch size
        H: number of heads
        P: prefix length
        C: chosen response length
        R: rejected response length
        D: hidden dimension per head
        dtype: data type of the generated inputs
        atol: absolute tolerance for comparison
        rtol: relative tolerance for comparison
        device: device on which inputs and masks are created
    """
    torch.manual_seed(0)

    # Total sequence length for shared version
    S = P + C + R

    # Initialize input tensors, i.e. the tensors after q, k, and v projections of hidden states (attention head input)
    query = torch.randn(B, H, S, D, device=device, dtype=dtype)
    key = torch.randn(B, H, S, D, device=device, dtype=dtype)
    value = torch.randn(B, H, S, D, device=device, dtype=dtype)

    # Split tensors along the sequence axis for the separate computations
    query_prefix, query_chosen, query_rejected = torch.split(query, [P, C, R], dim=2)
    key_prefix, key_chosen, key_rejected = torch.split(key, [P, C, R], dim=2)
    value_prefix, value_chosen, value_rejected = torch.split(value, [P, C, R], dim=2)

    # `prefix_mask` blocks queries at positions >= rejected_index from keys in
    # [chosen_index, rejected_index). These must therefore be the START offsets of the chosen
    # and rejected segments; using the end offsets (P + C and S) degenerates to a plain causal mask.
    chosen_index = torch.full((B,), P, device=device)
    rejected_index = torch.full((B,), P + C, device=device)

    def wrapped_prefix_mask(b, h, q_idx, kv_idx):
        return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)

    block_mask = create_block_mask(wrapped_prefix_mask, B, H, S, S, device=device)
    shared_out = flex_attention(query, key, value, block_mask=block_mask)

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    # Compute attention for prefix + chosen separately
    PC = P + C
    query_pc = torch.cat([query_prefix, query_chosen], dim=2)
    key_pc = torch.cat([key_prefix, key_chosen], dim=2)
    value_pc = torch.cat([value_prefix, value_chosen], dim=2)

    pc_block_mask = create_block_mask(causal_mask, B, H, PC, PC, device=device)
    pc_out = flex_attention(query_pc, key_pc, value_pc, block_mask=pc_block_mask)

    # Compute attention for prefix + rejected separately
    PR = P + R
    query_pr = torch.cat([query_prefix, query_rejected], dim=2)
    key_pr = torch.cat([key_prefix, key_rejected], dim=2)
    value_pr = torch.cat([value_prefix, value_rejected], dim=2)

    pr_block_mask = create_block_mask(causal_mask, B, H, PR, PR, device=device)
    pr_out = flex_attention(query_pr, key_pr, value_pr, block_mask=pr_block_mask)

    # Outputs have the same (B, H, seq, D) layout as the queries, so segments are selected on the
    # sequence axis only. Slicing the head-dim axis with sequence offsets would yield empty
    # tensors whenever D < P and make the comparisons below vacuous.
    shared_prefix = shared_out[:, :, :P, :]
    shared_chosen = shared_out[:, :, P : P + C, :]
    shared_rejected = shared_out[:, :, P + C :, :]

    separate_prefix_c = pc_out[:, :, :P, :]
    separate_chosen = pc_out[:, :, P:, :]
    separate_prefix_r = pr_out[:, :, :P, :]
    separate_rejected = pr_out[:, :, P:, :]

    # Verify prefix outputs are identical
    assert torch.allclose(
        shared_prefix, separate_prefix_c, atol=atol, rtol=rtol
    ), "Prefix attention from shared computation doesn't match prefix+chosen computation"
    assert torch.allclose(
        shared_prefix, separate_prefix_r, atol=atol, rtol=rtol
    ), "Prefix attention from shared computation doesn't match prefix+rejected computation"

    # Verify chosen and rejected outputs
    assert torch.allclose(
        shared_chosen, separate_chosen, atol=atol, rtol=rtol
    ), "Chosen response attention doesn't match between shared and separate computation"
    assert torch.allclose(
        shared_rejected, separate_rejected, atol=atol, rtol=rtol
    ), "Rejected response attention doesn't match between shared and separate computation"

    print("All attention values match between shared and separate computations!")
|
259
|
+
|
260
|
+
|
261
|
+
@pytest.mark.parametrize(
    "B, H, P, C, R, D",
    [
        (2, 8, 512, 256, 256, 32),
        (3, 12, 1024, 512, 512, 64),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        pytest.param(
            torch.bfloat16,
            3e-2,
            5e-1,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
        (torch.float16, 1e-2, 5e-3),
        (torch.float32, 1e-3, 5e-4),
    ],
)
def test_correctness_prefix(B, H, P, C, R, D, dtype, atol, rtol):
    """
    Parametrized prefix-sharing equivalence test over several shapes and dtypes.

    Runs on the module-level `device` (from `infer_device()`) instead of the helper's
    hard-coded "cuda" default.
    """
    _test_correctness_prefix(B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol, device=device)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221162633}/NOTICE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|