liger-kernel-nightly 0.6.4.dev20251229100549__tar.gz → 0.6.4.dev20260113013032__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.github/workflows/amd-ci.yml +4 -4
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.github/workflows/benchmark.yml +14 -10
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.github/workflows/docs.yml +3 -3
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.github/workflows/intel-ci.yml +3 -3
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.github/workflows/nvi-ci.yml +6 -6
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.github/workflows/publish-nightly.yml +2 -2
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.github/workflows/publish-release.yml +2 -2
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/backends/_ascend/ops/__init__.py +18 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/backends/_ascend/ops/geglu.py +34 -12
- liger_kernel_nightly-0.6.4.dev20260113013032/src/liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/backends/_ascend/ops/rope.py +2 -2
- liger_kernel_nightly-0.6.4.dev20260113013032/src/liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel_nightly-0.6.4.dev20260113013032/src/liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/backends/_ascend/ub_manager.py +1 -1
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/group_norm.py +10 -7
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/kl_div.py +8 -11
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/rms_norm.py +31 -22
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/utils.py +12 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/__init__.py +3 -0
- liger_kernel_nightly-0.6.4.dev20260113013032/src/liger_kernel/transformers/model/exaone4.py +136 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/gemma2.py +3 -3
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/gemma3.py +1 -2
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/monkey_patch.py +78 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/tiled_mlp.py +2 -10
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/utils.py +27 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel_nightly.egg-info/SOURCES.txt +5 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/convergence/bf16/test_mini_models.py +55 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/convergence/bf16/test_mini_models_with_logits.py +56 -1
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/convergence/fp32/test_mini_models.py +53 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/convergence/fp32/test_mini_models_multimodal.py +18 -4
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/convergence/fp32/test_mini_models_with_logits.py +63 -4
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_group_norm.py +5 -11
- liger_kernel_nightly-0.6.4.dev20260113013032/test/transformers/test_llama4_rope.py +149 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_monkey_patch.py +2 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_rms_norm.py +76 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_tiled_mlp.py +37 -56
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/utils.py +11 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/.gitignore +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/Makefile +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/README.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_distill_cosine_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_llama4_rope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_poly_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_softmax.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_sparse_multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_tiled_mlp.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/dev/modal/benchmarks.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/index.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/docs/license.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/setup.cfg +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/setup.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/cosine_similarity_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/backends/README.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/backends/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/backends/_ascend/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/backends/registry.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/dyt.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/llama4_rope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/poly_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/softmax.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/tiled_mlp.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/dyt.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/experimental/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/fsdp.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/llama4_rope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/falcon_h1.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/glm4.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/glm4v.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/glm4v_moe.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/gpt_oss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/hunyuan_v1.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/internvl.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/llama4.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/llava.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/olmo3.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/output_classes.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/paligemma.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/qwen3.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/qwen3_next.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/qwen3_vl.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/qwen3_vl_moe.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/smollm3.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/model/smolvlm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/poly_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/softmax.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/chunked_loss/test_cosine_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/HuggingFaceTB/SmolVLM2-256M-Video-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/OpenGVLab/InternVL3-1B-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/Qwen/Qwen3-VL-4B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/fake_configs/meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_dyt.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_poly_norm.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_softmax.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.6.4.dev20251229100549 → liger_kernel_nightly-0.6.4.dev20260113013032}/test/triton/test_triton_monkey_patch.py +0 -0
|
@@ -26,10 +26,10 @@ jobs:
|
|
|
26
26
|
|
|
27
27
|
steps:
|
|
28
28
|
- name: Checkout code
|
|
29
|
-
uses: actions/checkout@
|
|
29
|
+
uses: actions/checkout@v6
|
|
30
30
|
|
|
31
31
|
- name: Set up Python
|
|
32
|
-
uses: actions/setup-python@
|
|
32
|
+
uses: actions/setup-python@v6
|
|
33
33
|
with:
|
|
34
34
|
python-version: '3.10'
|
|
35
35
|
|
|
@@ -50,10 +50,10 @@ jobs:
|
|
|
50
50
|
|
|
51
51
|
steps:
|
|
52
52
|
- name: Checkout code
|
|
53
|
-
uses: actions/checkout@
|
|
53
|
+
uses: actions/checkout@v6
|
|
54
54
|
|
|
55
55
|
- name: Set up Python
|
|
56
|
-
uses: actions/setup-python@
|
|
56
|
+
uses: actions/setup-python@v6
|
|
57
57
|
with:
|
|
58
58
|
python-version: '3.10'
|
|
59
59
|
|
|
@@ -35,22 +35,26 @@ jobs:
|
|
|
35
35
|
OUTPUT_DIR: benchmarks
|
|
36
36
|
OUTPUT_FILENAME: benchmark.csv
|
|
37
37
|
GENERATED_CSV: benchmark/data/all_benchmark_data.csv
|
|
38
|
+
# Sanitize user-controlled inputs by declaring them as environment variables
|
|
39
|
+
# This prevents command injection: the shell receives each value as data, so it is never expanded into the script text
|
|
40
|
+
INPUT_COMMIT_HASH: ${{ github.event.inputs.commit_hash }}
|
|
41
|
+
INPUT_OVERWRITE: ${{ github.event.inputs.overwrite }}
|
|
38
42
|
|
|
39
43
|
|
|
40
44
|
steps:
|
|
41
45
|
# Step: Decide the commit hash to use
|
|
42
46
|
# Step: Checkout full history so we can check out any commit
|
|
43
47
|
- name: Checkout full repo history
|
|
44
|
-
uses: actions/checkout@
|
|
48
|
+
uses: actions/checkout@v6
|
|
45
49
|
with:
|
|
46
50
|
fetch-depth: 0 # Important: so we can checkout arbitrary commit
|
|
47
51
|
|
|
48
52
|
- name: Determine commit hash to checkout
|
|
49
53
|
id: choose_commit
|
|
50
54
|
run: |
|
|
51
|
-
if [ "${{ github.event_name}}" == "workflow_dispatch" ] && [ "$
|
|
52
|
-
echo "Using manual input commit: $
|
|
53
|
-
echo "hash=$
|
|
55
|
+
if [ "${{ github.event_name}}" == "workflow_dispatch" ] && [ "$INPUT_COMMIT_HASH" != "main" ]; then
|
|
56
|
+
echo "Using manual input commit: $INPUT_COMMIT_HASH"
|
|
57
|
+
echo "hash=$INPUT_COMMIT_HASH" >> $GITHUB_OUTPUT
|
|
54
58
|
else
|
|
55
59
|
echo "Using latest commit from main"
|
|
56
60
|
echo "hash=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
|
|
@@ -60,10 +64,10 @@ jobs:
|
|
|
60
64
|
- name: Replace benchmark folder from main (manual only, commit ≠ main)
|
|
61
65
|
if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.commit_hash != 'main' }}
|
|
62
66
|
run: |
|
|
63
|
-
echo "Detected manual trigger with commit_hash = $
|
|
67
|
+
echo "Detected manual trigger with commit_hash = $INPUT_COMMIT_HASH"
|
|
64
68
|
|
|
65
69
|
# Save current branch (detached HEAD at old commit)
|
|
66
|
-
ORIG_COMMIT
|
|
70
|
+
ORIG_COMMIT="$INPUT_COMMIT_HASH"
|
|
67
71
|
|
|
68
72
|
# Fetch and checkout main
|
|
69
73
|
git fetch origin main
|
|
@@ -72,7 +76,7 @@ jobs:
|
|
|
72
76
|
# Save benchmark folder from main
|
|
73
77
|
cp -r benchmark /tmp/benchmark_main
|
|
74
78
|
# Checkout back to target commit
|
|
75
|
-
git checkout $ORIG_COMMIT
|
|
79
|
+
git checkout "$ORIG_COMMIT"
|
|
76
80
|
# Replace old benchmark with one from main
|
|
77
81
|
rm -rf benchmark
|
|
78
82
|
cp -r /tmp/benchmark_main benchmark
|
|
@@ -85,7 +89,7 @@ jobs:
|
|
|
85
89
|
|
|
86
90
|
if curl --output /dev/null --silent --head --fail "$BENCHMARK_URL"; then
|
|
87
91
|
echo "Benchmark already exists for commit $COMMIT_HASH"
|
|
88
|
-
if [ "$
|
|
92
|
+
if [ "$INPUT_OVERWRITE" != "true" ]; then
|
|
89
93
|
echo "Overwrite is false - exiting"
|
|
90
94
|
exit 1
|
|
91
95
|
else
|
|
@@ -96,7 +100,7 @@ jobs:
|
|
|
96
100
|
fi
|
|
97
101
|
|
|
98
102
|
- name: Set up Python
|
|
99
|
-
uses: actions/setup-python@
|
|
103
|
+
uses: actions/setup-python@v6
|
|
100
104
|
with:
|
|
101
105
|
python-version: '3.10'
|
|
102
106
|
|
|
@@ -117,7 +121,7 @@ jobs:
|
|
|
117
121
|
|
|
118
122
|
# Step 5: Checkout gh-pages branch in a subfolder
|
|
119
123
|
- name: Checkout gh-pages
|
|
120
|
-
uses: actions/checkout@
|
|
124
|
+
uses: actions/checkout@v6
|
|
121
125
|
with:
|
|
122
126
|
ref: gh-pages
|
|
123
127
|
path: gh-pages
|
|
@@ -13,16 +13,16 @@ jobs:
|
|
|
13
13
|
deploy:
|
|
14
14
|
runs-on: ubuntu-latest
|
|
15
15
|
steps:
|
|
16
|
-
- uses: actions/checkout@
|
|
16
|
+
- uses: actions/checkout@v6
|
|
17
17
|
- name: Configure Git Credentials
|
|
18
18
|
run: |
|
|
19
19
|
git config user.name github-actions[bot]
|
|
20
20
|
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
|
|
21
|
-
- uses: actions/setup-python@
|
|
21
|
+
- uses: actions/setup-python@v6
|
|
22
22
|
with:
|
|
23
23
|
python-version: 3.x
|
|
24
24
|
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
|
|
25
|
-
- uses: actions/cache@
|
|
25
|
+
- uses: actions/cache@v5
|
|
26
26
|
with:
|
|
27
27
|
key: mkdocs-material-${{ env.cache_id }}
|
|
28
28
|
path: .cache
|
|
@@ -26,10 +26,10 @@ jobs:
|
|
|
26
26
|
|
|
27
27
|
steps:
|
|
28
28
|
- name: Checkout code
|
|
29
|
-
uses: actions/checkout@
|
|
29
|
+
uses: actions/checkout@v6
|
|
30
30
|
|
|
31
31
|
- name: Set up Python
|
|
32
|
-
uses: actions/setup-python@
|
|
32
|
+
uses: actions/setup-python@v6
|
|
33
33
|
with:
|
|
34
34
|
python-version: '3.10'
|
|
35
35
|
|
|
@@ -58,7 +58,7 @@ jobs:
|
|
|
58
58
|
apt-get clean && rm -rf /var/lib/apt/lists/*
|
|
59
59
|
|
|
60
60
|
- name: Checkout code
|
|
61
|
-
uses: actions/checkout@
|
|
61
|
+
uses: actions/checkout@v6
|
|
62
62
|
|
|
63
63
|
- name: Setup Dependencies
|
|
64
64
|
shell: bash
|
|
@@ -25,10 +25,10 @@ jobs:
|
|
|
25
25
|
|
|
26
26
|
steps:
|
|
27
27
|
- name: Checkout code
|
|
28
|
-
uses: actions/checkout@
|
|
28
|
+
uses: actions/checkout@v6
|
|
29
29
|
|
|
30
30
|
- name: Set up Python
|
|
31
|
-
uses: actions/setup-python@
|
|
31
|
+
uses: actions/setup-python@v6
|
|
32
32
|
with:
|
|
33
33
|
python-version: '3.10'
|
|
34
34
|
|
|
@@ -49,10 +49,10 @@ jobs:
|
|
|
49
49
|
|
|
50
50
|
steps:
|
|
51
51
|
- name: Checkout code
|
|
52
|
-
uses: actions/checkout@
|
|
52
|
+
uses: actions/checkout@v6
|
|
53
53
|
|
|
54
54
|
- name: Set up Python
|
|
55
|
-
uses: actions/setup-python@
|
|
55
|
+
uses: actions/setup-python@v6
|
|
56
56
|
with:
|
|
57
57
|
python-version: '3.10'
|
|
58
58
|
|
|
@@ -75,10 +75,10 @@ jobs:
|
|
|
75
75
|
|
|
76
76
|
steps:
|
|
77
77
|
- name: Checkout code
|
|
78
|
-
uses: actions/checkout@
|
|
78
|
+
uses: actions/checkout@v6
|
|
79
79
|
|
|
80
80
|
- name: Set up Python
|
|
81
|
-
uses: actions/setup-python@
|
|
81
|
+
uses: actions/setup-python@v6
|
|
82
82
|
with:
|
|
83
83
|
python-version: '3.10'
|
|
84
84
|
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel_nightly"
|
|
7
|
-
version = "0.6.4.
|
|
7
|
+
version = "0.6.4.dev20260113013032"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -17,15 +17,33 @@ If __all__ is not defined, all public symbols will be auto-discovered.
|
|
|
17
17
|
from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
|
|
18
18
|
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
|
|
19
19
|
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
|
|
20
|
+
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
|
21
|
+
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward
|
|
22
|
+
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward
|
|
20
23
|
from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction
|
|
21
24
|
from liger_kernel.ops.backends._ascend.ops.rope import rope_backward
|
|
22
25
|
from liger_kernel.ops.backends._ascend.ops.rope import rope_forward
|
|
26
|
+
from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction
|
|
27
|
+
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_backward
|
|
28
|
+
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_forward
|
|
29
|
+
from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction
|
|
30
|
+
from liger_kernel.ops.backends._ascend.ops.tvd import tv_distance_forward_triton
|
|
31
|
+
from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton
|
|
23
32
|
|
|
24
33
|
__all__ = [
|
|
25
34
|
"LigerGELUMulFunction",
|
|
26
35
|
"geglu_forward",
|
|
27
36
|
"geglu_backward",
|
|
37
|
+
"LigerQwen2VLMRopeFunction",
|
|
38
|
+
"qwen2vl_mrope_forward",
|
|
39
|
+
"qwen2vl_mrope_backward",
|
|
28
40
|
"LigerRopeFunction",
|
|
29
41
|
"rope_forward",
|
|
30
42
|
"rope_backward",
|
|
43
|
+
"LigerSiLUMulFunction",
|
|
44
|
+
"swiglu_forward",
|
|
45
|
+
"swiglu_backward",
|
|
46
|
+
"LigerTVDLossFunction",
|
|
47
|
+
"tv_distance_forward_triton",
|
|
48
|
+
"tvd_backward_triton",
|
|
31
49
|
]
|
|
@@ -130,20 +130,26 @@ def geglu_forward(a, b):
|
|
|
130
130
|
dtype_size = a.element_size()
|
|
131
131
|
# GEGLU forward tiling strategy:
|
|
132
132
|
# - Calculates maximum safe block size based on UB capacity
|
|
133
|
-
# - Memory analysis:
|
|
134
|
-
# * Inputs:
|
|
135
|
-
# *
|
|
136
|
-
# *
|
|
137
|
-
#
|
|
138
|
-
#
|
|
133
|
+
# - Memory analysis (only buffers that occupy UB, excluding temporary variables):
|
|
134
|
+
# * Inputs: a_row (4 bytes, float32), b_row (dtype_size bytes)
|
|
135
|
+
# * Output: c_row (dtype_size bytes)
|
|
136
|
+
# * Temporary variables (a_cubed, tanh_arg, tanh_result, geglu_a) are optimized to registers
|
|
137
|
+
# and don't occupy UB since they are only used once
|
|
138
|
+
# * For float16: a_row(4) + b_row(2) + c_row(2) = 8 bytes/element, ratio = 8/2 = 4.0
|
|
139
|
+
# * For float32: a_row(4) + b_row(4) + c_row(4) = 12 bytes/element, ratio = 12/4 = 3.0
|
|
140
|
+
# - Uses memory_multiplier=4.0 (float16) or 3.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
|
|
139
141
|
# - shapes: ((n_cols,),)
|
|
140
142
|
# - tiling_dims: (0,) means first dimension can be tiled
|
|
141
143
|
# - Returns: ((block_size,),)
|
|
142
144
|
shapes = ((n_cols,),)
|
|
145
|
+
if dtype_size == 2:
|
|
146
|
+
memory_multiplier = 4.0
|
|
147
|
+
else:
|
|
148
|
+
memory_multiplier = 3.0
|
|
143
149
|
tile_shapes = compute_default_tiling_strategy(
|
|
144
150
|
safety_margin=0.80,
|
|
145
151
|
dtype_size=dtype_size,
|
|
146
|
-
memory_multiplier=
|
|
152
|
+
memory_multiplier=memory_multiplier,
|
|
147
153
|
shapes=shapes,
|
|
148
154
|
tiling_dims=(0,),
|
|
149
155
|
)
|
|
@@ -187,18 +193,34 @@ def geglu_backward(a, b, dc):
|
|
|
187
193
|
dtype_size = dc.element_size()
|
|
188
194
|
# GEGLU backward tiling strategy:
|
|
189
195
|
# - Calculates maximum safe block size based on UB capacity
|
|
190
|
-
# - Memory analysis:
|
|
191
|
-
#
|
|
192
|
-
#
|
|
193
|
-
#
|
|
196
|
+
# - Memory analysis: Peak memory usage occurs when executing line 103 (term1 calculation)
|
|
197
|
+
# At this point, the following buffers simultaneously occupy UB:
|
|
198
|
+
# 1. dc_row = tl.load(dc + col_offsets, ...) # dtype_size bytes
|
|
199
|
+
# 2. a_row = tl.load(a + col_offsets, ...).to(tl.float32) # 4 bytes (float32)
|
|
200
|
+
# 3. b_row = tl.load(b + col_offsets, ...) # dtype_size bytes
|
|
201
|
+
# 4. tanh_result = tanh(tanh_arg) # 4 bytes (float32), used in lines 95, 103, 104
|
|
202
|
+
# 5. geglu_a = 0.5 * a_row * (1 + tanh_result) # 4 bytes (float32), used in lines 96, 98
|
|
203
|
+
# 6. db_row = dc_row.cast(tl.float32) * geglu_a # 4 bytes (float32, computed at line 98, stored at line 109)
|
|
204
|
+
# Note: term1 (line 103) is a temporary variable optimized to registers and doesn't occupy UB
|
|
205
|
+
# Temporary variables (a_cubed, tanh_arg, term1, tanh_sq, term2) are optimized to registers
|
|
206
|
+
# and don't occupy UB since they are only used once
|
|
207
|
+
# * For float16: dc_row(2) + a_row(4) + b_row(2) + tanh_result(4) + geglu_a(4) + db_row(4)
|
|
208
|
+
# = 20 bytes/element, ratio = 20/2 = 10.0
|
|
209
|
+
# * For float32: dc_row(4) + a_row(4) + b_row(4) + tanh_result(4) + geglu_a(4) + db_row(4)
|
|
210
|
+
# = 24 bytes/element, ratio = 24/4 = 6.0
|
|
211
|
+
# - Uses memory_multiplier=10.0 (float16) or 6.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
|
|
194
212
|
# - shapes: ((n_cols,),)
|
|
195
213
|
# - tiling_dims: (0,) means first dimension can be tiled
|
|
196
214
|
# - Returns: ((block_size,),)
|
|
197
215
|
shapes = ((n_cols,),)
|
|
216
|
+
if dtype_size == 2:
|
|
217
|
+
memory_multiplier = 10.0
|
|
218
|
+
else:
|
|
219
|
+
memory_multiplier = 6.0
|
|
198
220
|
tile_shapes = compute_default_tiling_strategy(
|
|
199
221
|
safety_margin=0.80,
|
|
200
222
|
dtype_size=dtype_size,
|
|
201
|
-
memory_multiplier=
|
|
223
|
+
memory_multiplier=memory_multiplier,
|
|
202
224
|
shapes=shapes,
|
|
203
225
|
tiling_dims=(0,),
|
|
204
226
|
)
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@triton.jit
|
|
9
|
+
def _triton_qwen2vl_mrope_npu(
|
|
10
|
+
q_ptr,
|
|
11
|
+
q_row_stride,
|
|
12
|
+
k_ptr,
|
|
13
|
+
k_row_stride,
|
|
14
|
+
cos,
|
|
15
|
+
sin,
|
|
16
|
+
sl,
|
|
17
|
+
bs: tl.constexpr,
|
|
18
|
+
n_qh: tl.constexpr,
|
|
19
|
+
n_kh: tl.constexpr,
|
|
20
|
+
hd: tl.constexpr,
|
|
21
|
+
mrope_section_t: tl.constexpr,
|
|
22
|
+
mrope_section_h: tl.constexpr,
|
|
23
|
+
BLOCK_Q: tl.constexpr,
|
|
24
|
+
BLOCK_K: tl.constexpr,
|
|
25
|
+
BACKWARD_PASS: tl.constexpr = False,
|
|
26
|
+
):
|
|
27
|
+
pid = tl.program_id(0).to(tl.int64)
|
|
28
|
+
|
|
29
|
+
t_end = mrope_section_t
|
|
30
|
+
h_end = t_end + mrope_section_h
|
|
31
|
+
|
|
32
|
+
t_cos = cos + pid * hd
|
|
33
|
+
h_cos = t_cos + bs * sl * hd
|
|
34
|
+
w_cos = h_cos + bs * sl * hd
|
|
35
|
+
t_sin = sin + pid * hd
|
|
36
|
+
h_sin = t_sin + bs * sl * hd
|
|
37
|
+
w_sin = h_sin + bs * sl * hd
|
|
38
|
+
|
|
39
|
+
q_base = q_ptr + pid * q_row_stride
|
|
40
|
+
k_base = k_ptr + pid * k_row_stride
|
|
41
|
+
|
|
42
|
+
d_idx = tl.arange(0, hd // 2)
|
|
43
|
+
d_mask = d_idx < (hd // 2)
|
|
44
|
+
|
|
45
|
+
pos_mask_t = d_idx < t_end
|
|
46
|
+
pos_mask_h = (d_idx >= t_end) & (d_idx < h_end)
|
|
47
|
+
|
|
48
|
+
text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0)
|
|
49
|
+
text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0)
|
|
50
|
+
height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0)
|
|
51
|
+
height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0)
|
|
52
|
+
width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0)
|
|
53
|
+
width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0)
|
|
54
|
+
|
|
55
|
+
cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals))
|
|
56
|
+
sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals))
|
|
57
|
+
|
|
58
|
+
for qh_block in range(0, n_qh, BLOCK_Q):
|
|
59
|
+
qh_idx = tl.arange(0, BLOCK_Q) + qh_block
|
|
60
|
+
qh_mask = qh_idx < n_qh
|
|
61
|
+
|
|
62
|
+
block_mask = qh_mask[:, None] & d_mask[None, :]
|
|
63
|
+
offsets = qh_idx[:, None] * hd + d_idx[None, :]
|
|
64
|
+
|
|
65
|
+
q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
|
|
66
|
+
q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
|
|
67
|
+
|
|
68
|
+
if not BACKWARD_PASS:
|
|
69
|
+
new_left = q_left * cos_vals - q_right * sin_vals
|
|
70
|
+
new_right = q_right * cos_vals + q_left * sin_vals
|
|
71
|
+
else:
|
|
72
|
+
new_left = q_left * cos_vals + q_right * sin_vals
|
|
73
|
+
new_right = q_right * cos_vals - q_left * sin_vals
|
|
74
|
+
|
|
75
|
+
tl.store(q_base + offsets, new_left, mask=block_mask)
|
|
76
|
+
tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
|
|
77
|
+
|
|
78
|
+
for kh_block in range(0, n_kh, BLOCK_K):
|
|
79
|
+
kh_idx = tl.arange(0, BLOCK_K) + kh_block
|
|
80
|
+
kh_mask = kh_idx < n_kh
|
|
81
|
+
|
|
82
|
+
block_mask = kh_mask[:, None] & d_mask[None, :]
|
|
83
|
+
offsets = kh_idx[:, None] * hd + d_idx[None, :]
|
|
84
|
+
|
|
85
|
+
k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
|
|
86
|
+
k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
|
|
87
|
+
|
|
88
|
+
if not BACKWARD_PASS:
|
|
89
|
+
new_left = k_left * cos_vals - k_right * sin_vals
|
|
90
|
+
new_right = k_right * cos_vals + k_left * sin_vals
|
|
91
|
+
else:
|
|
92
|
+
new_left = k_left * cos_vals + k_right * sin_vals
|
|
93
|
+
new_right = k_right * cos_vals - k_left * sin_vals
|
|
94
|
+
|
|
95
|
+
tl.store(k_base + offsets, new_left, mask=block_mask)
|
|
96
|
+
tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
    """Run the multimodal RoPE NPU kernel over ``q`` and ``k`` (forward direction).

    Args:
        q: query tensor; per the autograd wrapper it arrives as
            (bsz, n_q_head, seq_len, head_dim).
        k: key tensor; per the autograd wrapper it arrives as
            (bsz, n_kv_head, seq_len, head_dim).
        cos: cosine table, made contiguous before the launch.
        sin: sine table, made contiguous before the launch.
        mrope_section: indexable whose first two entries are forwarded to the kernel.

    Returns:
        ``(q, k, cos, sin)`` where ``q`` and ``k`` are the rotated buffers transposed
        back with ``transpose(1, 2)`` and ``cos`` / ``sin`` are the contiguous tables.
    """
    # Triton addresses physical storage, so swap the head and sequence axes and
    # materialize contiguous buffers (a no-op when the layout already matches).
    q = q.transpose(1, 2).contiguous()
    k = k.transpose(1, 2).contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]

    padded_head_dim = triton.next_power_of_2(head_dim)
    padded_q_heads = triton.next_power_of_2(n_q_head)
    padded_kv_heads = triton.next_power_of_2(n_kv_head)

    # Tiling is chosen from UB capacity. Sizing notes for the multiplier:
    #   * each head-block iteration keeps four half-width tiles alive
    #     (left, right, new_left, new_right):
    #     4 * BLOCK * (pad_hd // 2) = 2 * BLOCK * pad_hd elements;
    #   * q and k are processed in separate loops, so only the larger of
    #     BLOCK_Q / BLOCK_K counts towards the peak;
    #   * the cos/sin values are loaded once and shared by both loops;
    #   * memory_multiplier=3.0 rounds the estimate up for safety.
    # Only the head axis (dimension 0 of each shape) may be tiled, and the strategy
    # is expected to return ((block_size_q, pad_hd), (block_size_kv, pad_hd)).
    full_shapes = (
        (padded_q_heads, padded_head_dim),
        (padded_kv_heads, padded_head_dim),
    )
    tiled = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=q.element_size(),
        memory_multiplier=3.0,
        shapes=full_shapes,
        tiling_dims=(0, 0),
    )

    if tiled is None or len(tiled) != len(full_shapes):
        # No usable strategy: fall back to one block spanning all padded heads.
        BLOCK_Q = triton.next_power_of_2(padded_q_heads)
        BLOCK_K = triton.next_power_of_2(padded_kv_heads)
    else:
        (BLOCK_Q, _), (BLOCK_K, _) = tiled

    # One program per (batch, position) row.
    grid = (batch_size * seq_len,)
    _triton_qwen2vl_mrope_npu[grid](
        q,
        q.stride(1),
        k,
        k.stride(1),
        cos,
        sin,
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        mrope_section[0],
        mrope_section[1],
        BLOCK_Q,
        BLOCK_K,
        BACKWARD_PASS=False,
    )
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
    """Run the multimodal RoPE NPU kernel over the gradients (backward direction).

    The same kernel as the forward pass is launched with ``BACKWARD_PASS=True``,
    which flips the sign of the sine terms in the rotation.

    Args:
        dq: gradient w.r.t. the rotated queries; per the autograd wrapper it
            arrives as (bsz, n_q_head, seq_len, head_dim).
        dk: gradient w.r.t. the rotated keys; per the autograd wrapper it
            arrives as (bsz, n_kv_head, seq_len, head_dim).
        cos: cosine table, passed to the kernel unchanged.
        sin: sine table, passed to the kernel unchanged.
        mrope_section: indexable whose first two entries are forwarded to the kernel.

    Returns:
        ``(dq, dk)`` transposed back with ``transpose(1, 2)``.
    """
    # Swap the head and sequence axes and make the gradient buffers contiguous
    # before handing them to the kernel.
    dq = dq.transpose(1, 2).contiguous()
    dk = dk.transpose(1, 2).contiguous()

    batch_size, seq_len, n_q_head, head_dim = dq.shape
    n_kv_head = dk.shape[2]

    padded_head_dim = triton.next_power_of_2(head_dim)
    padded_q_heads = triton.next_power_of_2(n_q_head)
    padded_kv_heads = triton.next_power_of_2(n_kv_head)

    # Tiling is chosen from UB capacity. Sizing notes for the multiplier:
    #   * each head-block iteration keeps four half-width tiles alive
    #     (left, right, new_left, new_right):
    #     4 * BLOCK * (pad_hd // 2) = 2 * BLOCK * pad_hd elements;
    #   * q and k are processed in separate loops, so only the larger of
    #     BLOCK_Q / BLOCK_K counts towards the peak;
    #   * the cos/sin values are loaded once and shared by both loops;
    #   * memory_multiplier=3.0 rounds the estimate up for safety.
    # Only the head axis (dimension 0 of each shape) may be tiled, and the strategy
    # is expected to return ((block_size_q, pad_hd), (block_size_kv, pad_hd)).
    full_shapes = (
        (padded_q_heads, padded_head_dim),
        (padded_kv_heads, padded_head_dim),
    )
    tiled = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=dq.element_size(),
        memory_multiplier=3.0,
        shapes=full_shapes,
        tiling_dims=(0, 0),
    )

    if tiled is None or len(tiled) != len(full_shapes):
        # No usable strategy: fall back to one block spanning all padded heads.
        BLOCK_Q = triton.next_power_of_2(padded_q_heads)
        BLOCK_K = triton.next_power_of_2(padded_kv_heads)
    else:
        (BLOCK_Q, _), (BLOCK_K, _) = tiled

    # One program per (batch, position) row.
    grid = (batch_size * seq_len,)
    _triton_qwen2vl_mrope_npu[grid](
        dq,
        dq.stride(1),
        dk,
        dk.stride(1),
        cos,
        sin,
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        mrope_section[0],
        mrope_section[1],
        BLOCK_Q,
        BLOCK_K,
        BACKWARD_PASS=True,
    )
    return dq.transpose(1, 2), dk.transpose(1, 2)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class LigerQwen2VLMRopeFunction(torch.autograd.Function):
    """Autograd function wiring the Qwen2-VL multimodal RoPE kernels together.

    ``forward`` delegates to ``qwen2vl_mrope_forward`` and ``backward`` to
    ``qwen2vl_mrope_backward``; the contiguous ``cos`` / ``sin`` tables returned by
    the forward helper are saved on ``ctx`` so the backward pass can reuse them.
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)

        ``mrope_section`` is forwarded to the kernel launcher and stashed on ``ctx``
        for the backward pass. ``unsqueeze_dim`` is accepted but not used here.

        Returns the rotated ``(q, k)`` pair.
        """
        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
        ctx.save_for_backward(cos, sin)
        ctx.mrope_section = mrope_section
        return q, k

    # torch.autograd.Function requires backward to be a static method, same as forward.
    @staticmethod
    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)

        Returns one gradient slot per ``forward`` input: gradients for ``q`` and
        ``k``, and ``None`` for ``cos``, ``sin``, ``mrope_section`` and
        ``unsqueeze_dim``.
        """
        cos, sin = ctx.saved_tensors
        mrope_section = ctx.mrope_section
        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
        return dq, dk, None, None, None, None
|
|
@@ -239,8 +239,8 @@ def rope_backward(dq, dk, cos, sin):
|
|
|
239
239
|
BLOCK_K, _ = k_tile_shape
|
|
240
240
|
else:
|
|
241
241
|
# Fallback to conservative defaults
|
|
242
|
-
BLOCK_Q =
|
|
243
|
-
BLOCK_K =
|
|
242
|
+
BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
|
|
243
|
+
BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
|
|
244
244
|
|
|
245
245
|
_triton_rope_npu[(n_row,)](
|
|
246
246
|
dq,
|