liger-kernel-nightly 0.6.2.dev20251013144132__tar.gz → 0.6.2.dev20251014053719__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/PKG-INFO +1 -1
- liger_kernel_nightly-0.6.2.dev20251014053719/benchmark/scripts/benchmark_poly_norm.py +197 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/pyproject.toml +1 -1
- liger_kernel_nightly-0.6.2.dev20251014053719/src/liger_kernel/ops/poly_norm.py +386 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/__init__.py +2 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/functional.py +5 -0
- liger_kernel_nightly-0.6.2.dev20251014053719/src/liger_kernel/transformers/poly_norm.py +42 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
- liger_kernel_nightly-0.6.2.dev20251014053719/test/transformers/test_poly_norm.py +281 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/benchmark.yml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/.gitignore +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/Makefile +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/README.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_distill_cosine_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_llama4_rope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_softmax.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_sparse_multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/dev/modal/benchmarks.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/index.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/license.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/setup.cfg +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/setup.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/cosine_similarity_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/dyt.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/llama4_rope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/softmax.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/dyt.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/experimental/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/fsdp.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/llama4_rope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/falcon_h1.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/gemma3.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/glm4.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/glm4v.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/glm4v_moe.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/internvl.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/llama4.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/llava.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/paligemma.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/qwen3.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/smollm3.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/softmax.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_cosine_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/OpenGVLab/InternVL3-1B-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_dyt.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_softmax.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.6.2.dev20251013144132 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/utils.py +0 -0
@@ -0,0 +1,197 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import triton
|
4
|
+
|
5
|
+
from utils import QUANTILES
|
6
|
+
from utils import SingleBenchmarkRunInput
|
7
|
+
from utils import SingleBenchmarkRunOutput
|
8
|
+
from utils import _test_memory
|
9
|
+
from utils import parse_benchmark_script_args
|
10
|
+
from utils import run_benchmarks
|
11
|
+
|
12
|
+
from liger_kernel.transformers.poly_norm import LigerPolyNorm
|
13
|
+
from liger_kernel.utils import infer_device
|
14
|
+
|
15
|
+
device = infer_device()
|
16
|
+
|
17
|
+
|
18
|
+
class NaivePolyNorm(nn.Module):
    """Eager PyTorch PolyNorm used as the baseline in this benchmark.

    Computes, over the last dimension:

        y = w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b
        norm(u) = u / sqrt(mean(u^2) + eps)

    Reference:
        https://github.com/BryceZhuo/PolyCom/

    Args:
        eps: value added to the mean square before the reciprocal square root.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        # Initialisation follows the PolyCom reference: equal 1/3 weights, bias 1.0.
        self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
        self.bias = nn.Parameter(torch.tensor(1.0))
        self.variance_epsilon = eps

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension (RMSNorm without weight)."""
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.variance_epsilon)

    def forward(self, hidden_states):
        """Apply PolyNorm.

        Args:
            hidden_states: tensor of shape (..., H).

        Returns:
            Tensor with the same shape and dtype as ``hidden_states``; the
            arithmetic itself is carried out in float32.
        """
        original_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)

        # Ordered highest power first so powers[k] pairs with self.weight[k].
        powers = (x**3, x**2, x**1)
        result = self.weight[0] * self._norm(powers[0])
        for idx in (1, 2):
            result = result + self.weight[idx] * self._norm(powers[idx])
        result = result + self.bias

        return result.to(original_dtype)
|
68
|
+
|
69
|
+
|
70
|
+
def bench_speed_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time one PolyNorm provider at one hidden size.

    ``input.x`` is the hidden size (number of columns). ``input.extra_benchmark_config``
    supplies the row count ``M``, ``eps`` and ``dtype``. ``input.kernel_operation_mode``
    selects what is timed:

    * ``"forward"``: the forward call only;
    * ``"backward"``: repeated backward passes through a single forward graph;
    * ``"full"``: forward followed by backward.

    Returns:
        SingleBenchmarkRunOutput with the 20th/50th/80th percentile timings.

    Raises:
        ValueError: if the provider or the operation mode is not recognised.
    """
    N = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    # Fail fast. Previously an unknown provider made y_fwd() return None (crashing
    # later with an opaque AttributeError) and an unknown mode left ms_* unbound.
    if provider not in ("liger", "huggingface"):
        raise ValueError(f"unknown kernel provider: {provider!r}")
    if mode not in ("forward", "backward", "full"):
        raise ValueError(f"unknown kernel operation mode: {mode!r}")

    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, N)

    triton_poly = LigerPolyNorm(eps=eps).to(device)
    naive_poly = NaivePolyNorm(eps=eps).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        if provider == "liger":
            return triton_poly(x)
        return naive_poly(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            y_fwd,
            grad_to_none=[x],
            rep=500,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        # One forward graph, kept alive with retain_graph so backward can be re-run.
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            grad_to_none=[x],
            rep=500,
            quantiles=QUANTILES,
        )
    else:  # mode == "full"

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[x],
            rep=500,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
|
131
|
+
|
132
|
+
|
133
|
+
def bench_memory_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory of one forward + backward pass for one PolyNorm provider.

    ``input.x`` is the hidden size (number of columns). ``input.extra_benchmark_config``
    supplies the row count ``M``, ``eps`` and ``dtype``.

    Returns:
        SingleBenchmarkRunOutput with the 20th/50th/80th percentile values
        reported by ``_test_memory``.

    Raises:
        ValueError: if the provider is not recognised.
    """
    N = input.x
    provider = input.kernel_provider

    # Fail fast: an unknown provider used to make y_fwd() return None and crash
    # inside full() with an opaque AttributeError.
    if provider not in ("liger", "huggingface"):
        raise ValueError(f"unknown kernel provider: {provider!r}")

    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, N)

    triton_poly = LigerPolyNorm(eps=eps).to(device)
    naive_poly = NaivePolyNorm(eps=eps).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        if provider == "liger":
            return triton_poly(x)
        return naive_poly(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
|
169
|
+
|
170
|
+
|
171
|
+
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Sweep the hidden size over powers of two, 2**10 .. 2**15, with M = 2048 rows.
    common_configs = dict(
        kernel_name="poly_norm",
        x_name="H",
        x_label="hidden size",
        x_values=[1 << exp for exp in range(10, 16)],
        kernel_providers=["liger", "huggingface"],
        extra_benchmark_configs=[{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}],
        overwrite=args.overwrite,
    )

    # (benchmark function, operation modes, metric name, metric unit), run in order.
    suites = (
        (bench_speed_poly_norm, ["forward", "full", "backward"], "speed", "ms"),
        (bench_memory_poly_norm, ["full"], "memory", "MB"),
    )
    for bench_fn, op_modes, metric, unit in suites:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=op_modes,
            metric_name=metric,
            metric_unit=unit,
            **common_configs,
        )
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.6.2.
|
7
|
+
version = "0.6.2.dev20251014053719"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -0,0 +1,386 @@
|
|
1
|
+
import operator
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import triton
|
5
|
+
import triton.language as tl
|
6
|
+
|
7
|
+
from liger_kernel.ops.utils import calculate_settings
|
8
|
+
from liger_kernel.ops.utils import compare_version
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
10
|
+
|
11
|
+
if compare_version("triton", operator.ge, "3.0.0"):
|
12
|
+
try:
|
13
|
+
from triton.language.extra.libdevice import rsqrt
|
14
|
+
except ModuleNotFoundError:
|
15
|
+
from triton.language.extra.cuda.libdevice import rsqrt
|
16
|
+
else:
|
17
|
+
from triton.language.math import rsqrt
|
18
|
+
|
19
|
+
|
20
|
+
@triton.jit
def _poly_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,  # weight: [3] for [w0, w1, w2]
    B_ptr,  # bias: scalar
    RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm forward kernel; one program processes one row.

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    Reference:
        1. https://github.com/BryceZhuo/PolyCom/
        2. https://arxiv.org/pdf/2411.03884

    The three rstd values (for x³, x², x, in that column order) are cached in
    RSTD for the backward pass. All arithmetic is done in float32; the result is
    cast back to the input dtype on store.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Move pointers to this program's row
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    # Load input row and upcast: mean((x³)²) involves x⁶, which overflows fp16
    # for |x| > ~6.3 and loses most of its precision in bf16.
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
    X_row_dtype = X_row.dtype
    X_row = X_row.to(tl.float32)

    # Load weights and bias in fp32 as well
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)
    b = tl.load(B_ptr).to(tl.float32)

    # Compute x, x², x³ (masked columns are 0 and contribute nothing to the sums)
    X_pow1 = X_row
    X_pow2 = X_row * X_row
    X_pow3 = X_pow2 * X_row

    # norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
    mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
    rstd_3 = rsqrt(mean_square_3 + eps)
    norm_x3 = X_pow3 * rstd_3

    # norm(x²)
    mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
    rstd_2 = rsqrt(mean_square_2 + eps)
    norm_x2 = X_pow2 * rstd_2

    # norm(x)
    mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
    rstd_1 = rsqrt(mean_square_1 + eps)
    norm_x1 = X_pow1 * rstd_1

    # Cache rstd values for backward
    tl.store(RSTD_ptr + 0, rstd_3)
    tl.store(RSTD_ptr + 1, rstd_2)
    tl.store(RSTD_ptr + 2, rstd_1)

    # y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
    Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b

    # Store output in the input dtype
    tl.store(Y_ptr + col_offsets, Y_row.to(X_row_dtype), mask=mask)
|
93
|
+
|
94
|
+
|
95
|
+
@triton.jit
def _poly_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,  # shape: (n_programs, 3)
    dW_row_stride,
    dB_ptr,  # shape: (n_programs,)
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm backward kernel.

    Each program handles rows [row_block_id * rows_per_program, min(next, n_rows)).
    For those rows it writes dX, plus one partial row of dW and one dB entry that
    the host later sums across programs.

    Gradients (closed form):
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
        ∂L/∂w_p = S_p / D_p = rstd_p * S_p
        ∂L/∂b   = Σ_i grad_i

    where:
        - D_p = RMS(x^p) = 1/rstd_p
        - S_p = sum(grad * x^p) over the row
        - d = n_cols
        - p ∈ {3, 2, 1}
    """
    row_block_id = tl.program_id(0).to(tl.int64)
    row_start = row_block_id * rows_per_program
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Per-program partial sums for the scalar parameter gradients
    dW0_acc = 0.0
    dW1_acc = 0.0
    dW2_acc = 0.0
    dB_acc = 0.0

    # Load weights
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)

    dY_ptr += row_start * dY_row_stride
    dX_ptr += row_start * dX_row_stride
    X_ptr += row_start * X_row_stride
    RSTD_ptr += row_start * RSTD_row_stride

    for _ in range(row_start, row_end):
        # Load input and gradient
        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)

        # Load cached rstd values (column order: x³, x², x)
        rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
        rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
        rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)

        # Build each power once and reuse it: x³ = x²·x, x⁵ = x³·x²
        X_pow1 = X_row
        X_pow2 = X_row * X_row
        X_pow3 = X_pow2 * X_row
        X_pow5 = X_pow3 * X_pow2

        # Row reductions S_p = Σ grad * x^p
        S_3 = tl.sum(dY_row * X_pow3, axis=0)
        S_2 = tl.sum(dY_row * X_pow2, axis=0)
        S_1 = tl.sum(dY_row * X_pow1, axis=0)

        # Parameter gradients: dB = Σ dY, dW_p = rstd_p * S_p
        dB_acc += tl.sum(dY_row, axis=0)
        dW0_acc += rstd_3 * S_3
        dW1_acc += rstd_2 * S_2
        dW2_acc += rstd_1 * S_1

        rstd_3_cubed = rstd_3 * rstd_3 * rstd_3
        rstd_2_cubed = rstd_2 * rstd_2 * rstd_2
        rstd_1_cubed = rstd_1 * rstd_1 * rstd_1

        # p = 3: contribution of w0 * norm(x³)
        grad_x_3 = w0 * (3.0 * X_pow2 * rstd_3 * dY_row - (3.0 / n_cols) * X_pow5 * rstd_3_cubed * S_3)
        # p = 2: contribution of w1 * norm(x²)
        grad_x_2 = w1 * (2.0 * X_pow1 * rstd_2 * dY_row - (2.0 / n_cols) * X_pow3 * rstd_2_cubed * S_2)
        # p = 1: contribution of w2 * norm(x)
        grad_x_1 = w2 * (rstd_1 * dY_row - (1.0 / n_cols) * X_pow1 * rstd_1_cubed * S_1)

        dX_row = grad_x_3 + grad_x_2 + grad_x_1
        tl.store(dX_ptr + col_offsets, dX_row, mask=mask)

        # Advance to the next row
        dY_ptr += dY_row_stride
        dX_ptr += dX_row_stride
        X_ptr += X_row_stride
        RSTD_ptr += RSTD_row_stride

    # Store this program's partial parameter gradients
    tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
    tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
    tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
    tl.store(dB_ptr + row_block_id, dB_acc)
|
204
|
+
|
205
|
+
|
206
|
+
def poly_norm_forward(X, W, B, eps=1e-6):
    """
    Launch the PolyNorm forward kernel, one program per row.

    Args:
        X: input tensor of shape (*, H); it is flattened to (rows, H) with ``view``.
        W: weight tensor of shape (3,) holding [w0, w1, w2].
        B: single-element bias tensor.
        eps: added to each mean square before the reciprocal square root.

    Returns:
        Y: output with the same shape and dtype as X.
        X: the flattened 2D input (kept for backward).
        RSTD: float32 tensor of shape (rows, 3) with the cached rstd of x³, x², x.
        BLOCK_SIZE: kernel block size chosen for H.
        num_warps: warp count chosen for H.
    """
    orig_shape = X.shape
    X = X.view(-1, orig_shape[-1])
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    device = X.device
    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=device)
    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=device)

    assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
    assert B.numel() == 1, "Bias must be a scalar"

    # XPU-only launch option (grf_mode="large"); other devices get no extra args.
    kernel_args = {"grf_mode": "large"} if device.type == "xpu" else {}

    _poly_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        B,
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    return Y.view(*orig_shape), X, RSTD, BLOCK_SIZE, num_warps
|
260
|
+
|
261
|
+
|
262
|
+
def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
    """
    PolyNorm Backward Pass

    Args:
        dY: gradient of output
        X: input tensor (already reshaped to 2D)
        W: weight tensor
        RSTD: cached rstd values from forward, shape (rows, 3)
        BLOCK_SIZE: block size from forward
        num_warps: number of warps from forward
        in_place: if True, dY's storage is reused for dX (saves memory but
            overwrites the incoming gradient)

    Returns:
        dX: gradient w.r.t. input, same shape as dY
        dW: gradient w.r.t. weight, cast to W.dtype
        dB: gradient w.r.t. bias, cast to W.dtype
    """
    import math

    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # One program per SM (CUDA) / EU (XPU); 1 on other devices.
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count

    if in_place is True:
        dX = dY
    else:
        # No zero-fill needed: every row in [0, n_rows) belongs to exactly one
        # program and each store covers all n_cols columns, so the kernel
        # overwrites every element of dX.
        dX = torch.empty_like(dY)

    # Per-program partial sums, reduced on the host below.
    _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
    _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)

    rows_per_program = math.ceil(n_rows / sm_count)
    grid = (sm_count,)

    # XPU-only launch option
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"

    _poly_norm_backward_kernel[grid](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        W,
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        _dB,
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    # Reduce per-program partial gradients
    dX = dX.view(*shape)
    dW = _dW.sum(dim=0).to(W.dtype)
    dB = _dB.sum().to(W.dtype)

    return dX, dW, dB
|
339
|
+
|
340
|
+
|
341
|
+
class LigerPolyNormFunction(torch.autograd.Function):
    """
    Autograd wrapper around the fused PolyNorm Triton kernels.

    Forward:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    Backward (closed form):
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
        """
        Args:
            X: input tensor of shape (B, T, H) or (BxT, H)
            W: weight tensor of shape (3,) for [w0, w1, w2]
            B: bias scalar
            eps: epsilon for numerical stability
            in_place: if True, backward overwrites grad_output with dX (saves memory)

        Returns:
            Y: output tensor of same shape as X
        """
        Y, X_2d, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
        # Launch settings and the in-place flag are reused verbatim by backward.
        ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place = BLOCK_SIZE, num_warps, in_place
        ctx.save_for_backward(X_2d, W, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """
        Args:
            grad_output: gradient of the output

        Returns:
            Gradients for (X, W, B, eps, in_place); the last two are None
            because they are not tensors.
        """
        X_2d, W, RSTD = ctx.saved_tensors
        grads = poly_norm_backward(
            grad_output,
            X_2d,
            W,
            RSTD,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
            ctx.in_place,
        )
        dX, dW, dB = grads
        return dX, dW, dB, None, None
|
@@ -15,6 +15,7 @@ from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
|
|
15
15
|
from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
|
16
16
|
from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
|
17
17
|
from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
|
18
|
+
from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401
|
18
19
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
|
19
20
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
|
20
21
|
from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
|
@@ -137,6 +138,7 @@ __all__ = [
|
|
137
138
|
"LigerJSD",
|
138
139
|
"LigerLayerNorm",
|
139
140
|
"LigerFusedAddRMSNorm",
|
141
|
+
"LigerPolyNorm",
|
140
142
|
"LigerRMSNorm",
|
141
143
|
"liger_rotary_pos_emb",
|
142
144
|
"liger_llama4_text_rotary_pos_emb",
|
@@ -12,6 +12,7 @@ from liger_kernel.ops.jsd import LigerJSDFunction
|
|
12
12
|
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
|
13
13
|
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
|
14
14
|
from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
|
15
|
+
from liger_kernel.ops.poly_norm import LigerPolyNormFunction
|
15
16
|
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
16
17
|
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
|
17
18
|
from liger_kernel.ops.rope import LigerRopeFunction
|
@@ -258,6 +259,10 @@ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama",
|
|
258
259
|
return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
|
259
260
|
|
260
261
|
|
262
|
+
def liger_poly_norm(X, W, B, eps=1e-6, in_place=True):
    """Functional PolyNorm: ``w0*norm(x^3) + w1*norm(x^2) + w2*norm(x) + b``.

    Args:
        X: input tensor; normalization runs over its last dimension.
        W: weight tensor holding ``[w0, w1, w2]``.
        B: single-element bias tensor.
        eps: added to each mean square before the reciprocal square root.
        in_place: if True, backward reuses the incoming gradient's storage for dX.
    """
    apply_fn = LigerPolyNormFunction.apply
    return apply_fn(X, W, B, eps, in_place)
|
264
|
+
|
265
|
+
|
261
266
|
def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
|
262
267
|
return LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)
|
263
268
|
|
@@ -0,0 +1,42 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
|
4
|
+
from liger_kernel.ops.poly_norm import LigerPolyNormFunction
|
5
|
+
|
6
|
+
|
7
|
+
class LigerPolyNorm(nn.Module):
    """
    PolyNorm layer wrapper for Liger kernel.

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    Reference:
        https://github.com/BryceZhuo/PolyCom/

    Args:
        eps: epsilon for numerical stability (default: 1e-6)
        in_place: whether to in-place modify grad_output in backward to save memory
            (default: True). Set to False if grad_output must remain intact after
            the backward pass.
    """

    def __init__(self, eps=1e-6, in_place=True):
        super().__init__()
        # Align with PolyCom reference: initialize weights to (1/3, 1/3, 1/3) and bias to 1.0
        self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
        self.bias = nn.Parameter(torch.tensor(1.0))
        self.variance_epsilon = eps
        self.in_place = in_place

    def forward(self, hidden_states):
        """Apply PolyNorm over the last dimension of ``hidden_states``."""
        return LigerPolyNormFunction.apply(
            hidden_states,
            self.weight,
            self.bias,
            self.variance_epsilon,
            self.in_place,
        )

    def extra_repr(self):
        return f"weight_shape={tuple(self.weight.shape)}, eps={self.variance_epsilon}, in_place={self.in_place}"
|