liger-kernel-nightly 0.5.9.dev20250515034325__tar.gz → 0.5.9.dev20250515065336__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/data/all_benchmark_data.csv +72 -0
- liger_kernel_nightly-0.5.9.dev20250515065336/benchmark/scripts/benchmark_sparsemax.py +172 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/pyproject.toml +1 -1
- liger_kernel_nightly-0.5.9.dev20250515065336/src/liger_kernel/ops/sparsemax.py +167 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/functional.py +8 -0
- liger_kernel_nightly-0.5.9.dev20250515065336/src/liger_kernel/transformers/sparsemax.py +16 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
- liger_kernel_nightly-0.5.9.dev20250515065336/test/transformers/test_sparsemax.py +111 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.idea/workspace.xml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/Makefile +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/setup.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/dyt.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/dyt.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/gema3_rms.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma3.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/glm4.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/llava.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/paligemma.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen3.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_dyt.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/utils.py +0 -0
@@ -805,3 +805,75 @@ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.265
|
|
805
805
|
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
806
806
|
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
807
807
|
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
808
|
+
sparsemax,liger,forward,speed,ms,V,feature size,1024,0.41471999883651733,0.4126720130443573,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
809
|
+
sparsemax,liger,forward,speed,ms,V,feature size,2048,0.7608320116996765,0.7598080039024353,0.7628800272941589,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
810
|
+
sparsemax,liger,forward,speed,ms,V,feature size,4096,1.4561280012130737,1.4540799856185913,1.4581760168075562,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
811
|
+
sparsemax,liger,forward,speed,ms,V,feature size,8192,5.288959980010986,5.2848639488220215,5.29986572265625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
812
|
+
sparsemax,liger,forward,speed,ms,V,feature size,16384,10.734624862670898,10.729472160339355,11.096882820129395,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
813
|
+
sparsemax,liger,forward,speed,ms,V,feature size,32768,21.729312896728516,21.7128963470459,22.20728302001953,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
814
|
+
sparsemax,torch,forward,speed,ms,V,feature size,1024,0.42291200160980225,0.42188799381256104,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
815
|
+
sparsemax,torch,forward,speed,ms,V,feature size,2048,0.7782400250434875,0.7772160172462463,0.779263973236084,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
816
|
+
sparsemax,torch,forward,speed,ms,V,feature size,4096,1.4940160512924194,1.491968035697937,1.4960639476776123,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
817
|
+
sparsemax,torch,forward,speed,ms,V,feature size,8192,5.359615802764893,5.356544017791748,5.366579055786133,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
818
|
+
sparsemax,torch,forward,speed,ms,V,feature size,16384,10.883584022521973,10.874879837036133,11.224268913269043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
819
|
+
sparsemax,torch,forward,speed,ms,V,feature size,32768,22.19878387451172,22.018457412719727,22.48888397216797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
820
|
+
sparsemax,liger,full,speed,ms,V,feature size,1024,0.4558719992637634,0.45558398962020874,0.45772799849510193,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
821
|
+
sparsemax,liger,full,speed,ms,V,feature size,2048,0.8488960266113281,0.8478720188140869,0.8509439826011658,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
822
|
+
sparsemax,liger,full,speed,ms,V,feature size,4096,1.6476160287857056,1.6465920209884644,1.6499264240264893,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
823
|
+
sparsemax,liger,full,speed,ms,V,feature size,8192,5.664768218994141,5.660672187805176,5.681356906890869,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
824
|
+
sparsemax,liger,full,speed,ms,V,feature size,16384,11.486207962036133,11.478015899658203,11.874713897705078,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
825
|
+
sparsemax,liger,full,speed,ms,V,feature size,32768,23.457279205322266,23.289682388305664,23.76642608642578,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
826
|
+
sparsemax,torch,full,speed,ms,V,feature size,1024,0.6021119952201843,0.6010879874229431,0.6041600108146667,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
827
|
+
sparsemax,torch,full,speed,ms,V,feature size,2048,1.1212799549102783,1.119264006614685,1.1223039627075195,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
828
|
+
sparsemax,torch,full,speed,ms,V,feature size,4096,2.1637120246887207,2.1616640090942383,2.165760040283203,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
829
|
+
sparsemax,torch,full,speed,ms,V,feature size,8192,6.693888187408447,6.68723201751709,6.705561637878418,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
830
|
+
sparsemax,torch,full,speed,ms,V,feature size,16384,13.523456573486328,13.518848419189453,13.878681182861328,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
831
|
+
sparsemax,torch,full,speed,ms,V,feature size,32768,27.604991912841797,27.295129776000977,27.77518081665039,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
832
|
+
sparsemax,liger,backward,speed,ms,V,feature size,1024,0.04403200000524521,0.043007999658584595,0.05222399905323982,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
833
|
+
sparsemax,liger,backward,speed,ms,V,feature size,2048,0.08806400001049042,0.08713600039482117,0.08806400001049042,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
834
|
+
sparsemax,liger,backward,speed,ms,V,feature size,4096,0.1884160041809082,0.1884160041809082,0.18943999707698822,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
835
|
+
sparsemax,liger,backward,speed,ms,V,feature size,8192,0.374783992767334,0.37376001477241516,0.37486720085144043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
836
|
+
sparsemax,liger,backward,speed,ms,V,feature size,16384,0.7516160011291504,0.7505919933319092,0.7516160011291504,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
837
|
+
sparsemax,liger,backward,speed,ms,V,feature size,32768,1.5738879442214966,1.572864055633545,1.575935959815979,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
838
|
+
sparsemax,torch,backward,speed,ms,V,feature size,1024,0.1812479943037033,0.1802240014076233,0.18227200210094452,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
839
|
+
sparsemax,torch,backward,speed,ms,V,feature size,2048,0.34406399726867676,0.34406399726867676,0.34508800506591797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
840
|
+
sparsemax,torch,backward,speed,ms,V,feature size,4096,0.6717439889907837,0.6707199811935425,0.6727679967880249,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
841
|
+
sparsemax,torch,backward,speed,ms,V,feature size,8192,1.3250559568405151,1.3241215944290161,1.3260799646377563,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
842
|
+
sparsemax,torch,backward,speed,ms,V,feature size,16384,2.629631996154785,2.628607988357544,2.6306560039520264,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
843
|
+
sparsemax,torch,backward,speed,ms,V,feature size,32768,5.236735820770264,5.235712051391602,5.239808082580566,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
844
|
+
sparsemax,liger,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
845
|
+
sparsemax,liger,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
846
|
+
sparsemax,liger,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
847
|
+
sparsemax,liger,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
848
|
+
sparsemax,liger,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
849
|
+
sparsemax,liger,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
850
|
+
sparsemax,torch,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
851
|
+
sparsemax,torch,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
852
|
+
sparsemax,torch,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
853
|
+
sparsemax,torch,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
854
|
+
sparsemax,torch,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
855
|
+
sparsemax,torch,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
856
|
+
sparsemax,liger,forward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
857
|
+
sparsemax,liger,forward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
858
|
+
sparsemax,liger,forward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
859
|
+
sparsemax,liger,forward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
860
|
+
sparsemax,liger,forward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
861
|
+
sparsemax,liger,forward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
862
|
+
sparsemax,torch,forward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
863
|
+
sparsemax,torch,forward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
864
|
+
sparsemax,torch,forward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
865
|
+
sparsemax,torch,forward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
866
|
+
sparsemax,torch,forward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
867
|
+
sparsemax,torch,forward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
868
|
+
sparsemax,liger,backward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
869
|
+
sparsemax,liger,backward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
870
|
+
sparsemax,liger,backward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
871
|
+
sparsemax,liger,backward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
872
|
+
sparsemax,liger,backward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
873
|
+
sparsemax,liger,backward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
874
|
+
sparsemax,torch,backward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
875
|
+
sparsemax,torch,backward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
876
|
+
sparsemax,torch,backward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
877
|
+
sparsemax,torch,backward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
878
|
+
sparsemax,torch,backward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
879
|
+
sparsemax,torch,backward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
@@ -0,0 +1,172 @@
|
|
1
|
+
import torch
|
2
|
+
import triton
|
3
|
+
|
4
|
+
from utils import QUANTILES
|
5
|
+
from utils import SingleBenchmarkRunInput
|
6
|
+
from utils import SingleBenchmarkRunOutput
|
7
|
+
from utils import _test_memory
|
8
|
+
from utils import parse_benchmark_script_args
|
9
|
+
from utils import run_benchmarks
|
10
|
+
|
11
|
+
from liger_kernel.transformers.sparsemax import LigerSparsemax
|
12
|
+
from liger_kernel.utils import infer_device
|
13
|
+
|
14
|
+
device = infer_device()
|
15
|
+
|
16
|
+
|
17
|
+
def torch_sparsemax(input_tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
18
|
+
input_dims = input_tensor.dim()
|
19
|
+
if dim < 0:
|
20
|
+
dim = input_dims + dim
|
21
|
+
input_sorted, _ = torch.sort(input_tensor, dim=dim, descending=True)
|
22
|
+
cumsum_input = torch.cumsum(input_sorted, dim=dim)
|
23
|
+
input_size = input_tensor.size(dim)
|
24
|
+
range_tensor = torch.arange(1, input_size + 1, device=input_tensor.device, dtype=input_tensor.dtype)
|
25
|
+
shape = [1] * input_dims
|
26
|
+
shape[dim] = input_size
|
27
|
+
range_tensor = range_tensor.view(shape)
|
28
|
+
k_bound = 1 + range_tensor * input_sorted
|
29
|
+
support = k_bound > cumsum_input
|
30
|
+
k = support.sum(dim=dim, keepdim=True).clamp(min=1)
|
31
|
+
support_sum = (input_sorted * support).sum(dim=dim, keepdim=True)
|
32
|
+
tau = (support_sum - 1) / k
|
33
|
+
return torch.clamp(input_tensor - tau, min=0)
|
34
|
+
|
35
|
+
|
36
|
+
class TorchSparsemax(torch.nn.Module):
|
37
|
+
def __init__(self, dim: int = -1):
|
38
|
+
super().__init__()
|
39
|
+
self.dim = dim
|
40
|
+
|
41
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
42
|
+
return torch_sparsemax(x, dim=self.dim)
|
43
|
+
|
44
|
+
|
45
|
+
def bench_speed_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
|
46
|
+
V = input.x
|
47
|
+
provider = input.kernel_provider
|
48
|
+
mode = input.kernel_operation_mode
|
49
|
+
|
50
|
+
extra_benchmark_config = input.extra_benchmark_config
|
51
|
+
B = extra_benchmark_config["B"]
|
52
|
+
T = extra_benchmark_config["T"]
|
53
|
+
dim = extra_benchmark_config["dim"]
|
54
|
+
dtype = extra_benchmark_config["dtype"]
|
55
|
+
|
56
|
+
x_shape = (B * T, V)
|
57
|
+
|
58
|
+
torch_sparsemax_module = TorchSparsemax(dim=dim).to(device)
|
59
|
+
liger_sparsemax_module = LigerSparsemax(dim=dim).to(device)
|
60
|
+
|
61
|
+
x = torch.randn(x_shape, dtype=dtype, device=device)
|
62
|
+
dy = torch.randn_like(x)
|
63
|
+
x.requires_grad_(True)
|
64
|
+
|
65
|
+
# utility functions
|
66
|
+
def y_fwd():
|
67
|
+
if provider == "liger":
|
68
|
+
return liger_sparsemax_module(x)
|
69
|
+
elif provider == "torch":
|
70
|
+
return torch_sparsemax_module(x)
|
71
|
+
|
72
|
+
if mode == "forward":
|
73
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
74
|
+
y_fwd,
|
75
|
+
grad_to_none=[x],
|
76
|
+
rep=500,
|
77
|
+
quantiles=QUANTILES,
|
78
|
+
)
|
79
|
+
elif mode == "backward":
|
80
|
+
y = y_fwd()
|
81
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
82
|
+
lambda: y.backward(dy, retain_graph=True),
|
83
|
+
grad_to_none=[x],
|
84
|
+
rep=500,
|
85
|
+
quantiles=QUANTILES,
|
86
|
+
)
|
87
|
+
elif mode == "full":
|
88
|
+
|
89
|
+
def full():
|
90
|
+
y = y_fwd()
|
91
|
+
y.backward(dy, retain_graph=True)
|
92
|
+
|
93
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
94
|
+
full,
|
95
|
+
grad_to_none=[x],
|
96
|
+
rep=500,
|
97
|
+
quantiles=QUANTILES,
|
98
|
+
)
|
99
|
+
|
100
|
+
return SingleBenchmarkRunOutput(
|
101
|
+
y_20=ms_20,
|
102
|
+
y_50=ms_50,
|
103
|
+
y_80=ms_80,
|
104
|
+
)
|
105
|
+
|
106
|
+
|
107
|
+
def bench_memory_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
|
108
|
+
V = input.x
|
109
|
+
provider = input.kernel_provider
|
110
|
+
|
111
|
+
extra_benchmark_config = input.extra_benchmark_config
|
112
|
+
B = extra_benchmark_config["B"]
|
113
|
+
T = extra_benchmark_config["T"]
|
114
|
+
dim = extra_benchmark_config["dim"]
|
115
|
+
dtype = extra_benchmark_config["dtype"]
|
116
|
+
|
117
|
+
x_shape = (B * T, V)
|
118
|
+
|
119
|
+
torch_sparsemax_module = TorchSparsemax(dim=dim).to(device)
|
120
|
+
liger_sparsemax_module = LigerSparsemax(dim=dim).to(device)
|
121
|
+
|
122
|
+
x = torch.randn(x_shape, dtype=dtype, device=device)
|
123
|
+
dy = torch.randn_like(x)
|
124
|
+
x.requires_grad_(True)
|
125
|
+
|
126
|
+
# utility functions
|
127
|
+
def y_fwd():
|
128
|
+
if provider == "liger":
|
129
|
+
return liger_sparsemax_module(x)
|
130
|
+
elif provider == "torch":
|
131
|
+
return torch_sparsemax_module(x)
|
132
|
+
|
133
|
+
def full():
|
134
|
+
y = y_fwd()
|
135
|
+
y.backward(dy, retain_graph=True)
|
136
|
+
|
137
|
+
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
|
138
|
+
|
139
|
+
return SingleBenchmarkRunOutput(
|
140
|
+
y_20=mem_20,
|
141
|
+
y_50=mem_50,
|
142
|
+
y_80=mem_80,
|
143
|
+
)
|
144
|
+
|
145
|
+
|
146
|
+
if __name__ == "__main__":
|
147
|
+
args = parse_benchmark_script_args()
|
148
|
+
|
149
|
+
common_configs = {
|
150
|
+
"kernel_name": "sparsemax",
|
151
|
+
"x_name": "V",
|
152
|
+
"x_label": "feature size",
|
153
|
+
"x_values": [2**i for i in range(10, 16)],
|
154
|
+
"kernel_providers": ["liger", "torch"],
|
155
|
+
"extra_benchmark_configs": [{"B": 4, "T": 512, "dim": -1, "dtype": torch.float32}],
|
156
|
+
"overwrite": args.overwrite,
|
157
|
+
}
|
158
|
+
|
159
|
+
run_benchmarks(
|
160
|
+
bench_test_fn=bench_speed_sparsemax,
|
161
|
+
kernel_operation_modes=["forward", "full", "backward"],
|
162
|
+
metric_name="speed",
|
163
|
+
metric_unit="ms",
|
164
|
+
**common_configs,
|
165
|
+
)
|
166
|
+
run_benchmarks(
|
167
|
+
bench_test_fn=bench_memory_sparsemax,
|
168
|
+
kernel_operation_modes=["full"],
|
169
|
+
metric_name="memory",
|
170
|
+
metric_unit="MB",
|
171
|
+
**common_configs,
|
172
|
+
)
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.5.9.dev20250515034325"
|
7
|
+
version = "0.5.9.dev20250515065336"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -0,0 +1,167 @@
|
|
1
|
+
import torch
|
2
|
+
import triton
|
3
|
+
import triton.language as tl
|
4
|
+
|
5
|
+
from liger_kernel.ops.utils import calculate_settings
|
6
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
7
|
+
|
8
|
+
|
9
|
+
@triton.jit
|
10
|
+
def _sparsemax_forward_kernel(
|
11
|
+
x_ptr,
|
12
|
+
x_stride_row,
|
13
|
+
sorted_x_ptr,
|
14
|
+
sorted_x_stride_row,
|
15
|
+
o_ptr,
|
16
|
+
o_stride_row,
|
17
|
+
n_cols,
|
18
|
+
BLOCK_SIZE: tl.constexpr,
|
19
|
+
num_warps: tl.constexpr,
|
20
|
+
):
|
21
|
+
pid_row = tl.program_id(0)
|
22
|
+
ptr_x_data_row = x_ptr + pid_row * x_stride_row
|
23
|
+
ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row
|
24
|
+
ptr_output_row = o_ptr + pid_row * o_stride_row
|
25
|
+
|
26
|
+
offs = tl.arange(0, BLOCK_SIZE)
|
27
|
+
mask = offs < n_cols
|
28
|
+
|
29
|
+
z_sorted_block = tl.load(
|
30
|
+
ptr_sorted_x_data_row + offs,
|
31
|
+
mask=mask,
|
32
|
+
other=-float("inf"),
|
33
|
+
cache_modifier=".ca",
|
34
|
+
).to(tl.float32)
|
35
|
+
|
36
|
+
z_valid = tl.where(mask, z_sorted_block, 0.0)
|
37
|
+
cssv = tl.cumsum(z_valid, 0)
|
38
|
+
|
39
|
+
r = (offs + 1).to(tl.float32)
|
40
|
+
safe_r = tl.where(mask, r, 1.0)
|
41
|
+
|
42
|
+
t_vec = (cssv - 1.0) / safe_r
|
43
|
+
|
44
|
+
support = (z_sorted_block > t_vec) & mask
|
45
|
+
|
46
|
+
k_int = tl.sum(support.to(tl.int32), 0)
|
47
|
+
k_clamped_int = tl.maximum(k_int, 1)
|
48
|
+
k = k_clamped_int.to(tl.float32)
|
49
|
+
|
50
|
+
s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)
|
51
|
+
|
52
|
+
tau = (s - 1.0) / k
|
53
|
+
|
54
|
+
x_block = tl.load(
|
55
|
+
ptr_x_data_row + offs,
|
56
|
+
mask=mask,
|
57
|
+
other=0.0,
|
58
|
+
cache_modifier=".ca",
|
59
|
+
).to(tl.float32)
|
60
|
+
|
61
|
+
y = tl.maximum(x_block - tau, 0.0)
|
62
|
+
|
63
|
+
tl.store(
|
64
|
+
ptr_output_row + offs,
|
65
|
+
y.to(ptr_output_row.dtype.element_ty),
|
66
|
+
mask=mask,
|
67
|
+
cache_modifier=".cs",
|
68
|
+
)
|
69
|
+
|
70
|
+
|
71
|
+
@triton.jit
|
72
|
+
def _sparsemax_backward_kernel(
|
73
|
+
o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
|
74
|
+
):
|
75
|
+
row = tl.program_id(0)
|
76
|
+
o_row = o_ptr + row * stride
|
77
|
+
go_row = go_ptr + row * stride
|
78
|
+
gi_row = gi_ptr + row * stride
|
79
|
+
|
80
|
+
offs = tl.arange(0, BLOCK_SIZE)
|
81
|
+
|
82
|
+
supp_cnt = tl.zeros((), tl.float32)
|
83
|
+
go_sum = tl.zeros((), tl.float32)
|
84
|
+
|
85
|
+
for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
|
86
|
+
offs_iter = i * BLOCK_SIZE + offs
|
87
|
+
mask_iter = offs_iter < n_cols
|
88
|
+
o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
|
89
|
+
go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
|
90
|
+
supp = o_val > 0.0
|
91
|
+
go_sum += tl.sum(tl.where(supp, go_val, 0.0))
|
92
|
+
supp_cnt += tl.sum(supp.to(tl.float32))
|
93
|
+
|
94
|
+
for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
|
95
|
+
offs_iter = i * BLOCK_SIZE + offs
|
96
|
+
mask_iter = offs_iter < n_cols
|
97
|
+
o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
|
98
|
+
go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
|
99
|
+
supp = o_val > 0.0
|
100
|
+
gi_val = tl.where(
|
101
|
+
supp,
|
102
|
+
go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
|
103
|
+
0.0,
|
104
|
+
)
|
105
|
+
tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")
|
106
|
+
|
107
|
+
|
108
|
+
class LigerSparsemaxFunction(torch.autograd.Function):
|
109
|
+
@staticmethod
|
110
|
+
@ensure_contiguous
|
111
|
+
def forward(ctx, x: torch.Tensor, dim: int):
|
112
|
+
if dim < 0:
|
113
|
+
dim += x.dim()
|
114
|
+
ctx.dim = dim
|
115
|
+
|
116
|
+
x_sw = x.transpose(dim, -1).contiguous()
|
117
|
+
n_cols = x_sw.size(-1)
|
118
|
+
n_rows = x_sw.numel() // n_cols
|
119
|
+
x_flat = x_sw.view(n_rows, n_cols)
|
120
|
+
|
121
|
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
122
|
+
out_flat = torch.empty_like(x_flat)
|
123
|
+
grid = (n_rows,)
|
124
|
+
|
125
|
+
x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
|
126
|
+
|
127
|
+
_sparsemax_forward_kernel[grid](
|
128
|
+
x_flat,
|
129
|
+
x_flat.stride(0),
|
130
|
+
x_sorted_flat,
|
131
|
+
x_sorted_flat.stride(0),
|
132
|
+
out_flat,
|
133
|
+
out_flat.stride(0),
|
134
|
+
n_cols,
|
135
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
136
|
+
num_warps=num_warps,
|
137
|
+
)
|
138
|
+
|
139
|
+
ctx.save_for_backward(out_flat)
|
140
|
+
return out_flat.view_as(x_sw).transpose(dim, -1)
|
141
|
+
|
142
|
+
@staticmethod
|
143
|
+
@ensure_contiguous
|
144
|
+
def backward(ctx, grad_out: torch.Tensor):
|
145
|
+
(out_flat,) = ctx.saved_tensors
|
146
|
+
dim = ctx.dim
|
147
|
+
|
148
|
+
go_sw = grad_out.transpose(dim, -1).contiguous()
|
149
|
+
n_cols = go_sw.size(-1)
|
150
|
+
n_rows = go_sw.numel() // n_cols
|
151
|
+
go_flat = go_sw.view(n_rows, n_cols)
|
152
|
+
|
153
|
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
154
|
+
gi_flat = torch.empty_like(go_flat)
|
155
|
+
grid = (n_rows,)
|
156
|
+
|
157
|
+
_sparsemax_backward_kernel[grid](
|
158
|
+
out_flat,
|
159
|
+
go_flat,
|
160
|
+
gi_flat,
|
161
|
+
out_flat.stride(0),
|
162
|
+
n_cols,
|
163
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
164
|
+
num_warps=num_warps,
|
165
|
+
)
|
166
|
+
|
167
|
+
return gi_flat.view_as(go_sw).transpose(dim, -1), None
|
@@ -12,6 +12,7 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
|
|
12
12
|
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
13
13
|
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
|
14
14
|
from liger_kernel.ops.rope import LigerRopeFunction
|
15
|
+
from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
|
15
16
|
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
|
16
17
|
from liger_kernel.ops.tvd import LigerTVDLossFunction
|
17
18
|
|
@@ -159,6 +160,13 @@ def liger_kl_div(
|
|
159
160
|
)
|
160
161
|
|
161
162
|
|
163
|
+
def liger_sparsemax(
|
164
|
+
input,
|
165
|
+
dim: int = -1,
|
166
|
+
):
|
167
|
+
return LigerSparsemaxFunction.apply(input, dim)
|
168
|
+
|
169
|
+
|
162
170
|
def liger_tvd(
|
163
171
|
input,
|
164
172
|
target,
|
@@ -0,0 +1,16 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
|
4
|
+
from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
|
5
|
+
|
6
|
+
|
7
|
+
class LigerSparsemax(nn.Module):
|
8
|
+
def __init__(self, dim: int = -1):
|
9
|
+
super().__init__()
|
10
|
+
self.dim = dim
|
11
|
+
|
12
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
13
|
+
return LigerSparsemaxFunction.apply(x, self.dim)
|
14
|
+
|
15
|
+
def extra_repr(self) -> str:
|
16
|
+
return f"dim={self.dim}"
|
@@ -40,6 +40,7 @@ benchmark/scripts/benchmark_qwen2vl_mrope.py
|
|
40
40
|
benchmark/scripts/benchmark_rms_norm.py
|
41
41
|
benchmark/scripts/benchmark_rope.py
|
42
42
|
benchmark/scripts/benchmark_simpo_loss.py
|
43
|
+
benchmark/scripts/benchmark_sparsemax.py
|
43
44
|
benchmark/scripts/benchmark_swiglu.py
|
44
45
|
benchmark/scripts/benchmark_tvd.py
|
45
46
|
benchmark/scripts/utils.py
|
@@ -134,6 +135,7 @@ src/liger_kernel/ops/layer_norm.py
|
|
134
135
|
src/liger_kernel/ops/qwen2vl_mrope.py
|
135
136
|
src/liger_kernel/ops/rms_norm.py
|
136
137
|
src/liger_kernel/ops/rope.py
|
138
|
+
src/liger_kernel/ops/sparsemax.py
|
137
139
|
src/liger_kernel/ops/swiglu.py
|
138
140
|
src/liger_kernel/ops/tvd.py
|
139
141
|
src/liger_kernel/ops/utils.py
|
@@ -156,6 +158,7 @@ src/liger_kernel/transformers/monkey_patch.py
|
|
156
158
|
src/liger_kernel/transformers/qwen2vl_mrope.py
|
157
159
|
src/liger_kernel/transformers/rms_norm.py
|
158
160
|
src/liger_kernel/transformers/rope.py
|
161
|
+
src/liger_kernel/transformers/sparsemax.py
|
159
162
|
src/liger_kernel/transformers/swiglu.py
|
160
163
|
src/liger_kernel/transformers/trainer_integration.py
|
161
164
|
src/liger_kernel/transformers/tvd.py
|
@@ -238,6 +241,7 @@ test/transformers/test_monkey_patch.py
|
|
238
241
|
test/transformers/test_qwen2vl_mrope.py
|
239
242
|
test/transformers/test_rms_norm.py
|
240
243
|
test/transformers/test_rope.py
|
244
|
+
test/transformers/test_sparsemax.py
|
241
245
|
test/transformers/test_swiglu.py
|
242
246
|
test/transformers/test_trainer_integration.py
|
243
247
|
test/transformers/test_transformers.py
|
@@ -0,0 +1,111 @@
|
|
1
|
+
import pytest
|
2
|
+
import torch
|
3
|
+
|
4
|
+
from test.utils import assert_verbose_allclose
|
5
|
+
from test.utils import set_seed
|
6
|
+
|
7
|
+
from liger_kernel.transformers.functional import liger_sparsemax
|
8
|
+
from liger_kernel.transformers.sparsemax import LigerSparsemax
|
9
|
+
from liger_kernel.utils import infer_device
|
10
|
+
|
11
|
+
# Module-wide device used by the tests below when allocating input tensors
# and when moving the LigerSparsemax module (`device=device`, `.to(device)`).
device = infer_device()
|
12
|
+
|
13
|
+
|
14
|
+
def torch_sparsemax(input_tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Reference sparsemax built from plain PyTorch ops.

    Sorts the values along ``dim`` in descending order, finds the entries
    whose ``1 + k * z_(k)`` exceeds the running sum of the sorted values
    (the support), derives the threshold ``tau`` from that support and
    returns ``clamp(input_tensor - tau, min=0)``.

    Args:
        input_tensor: Tensor to transform.
        dim: Axis along which sparsemax is computed; a negative value is
            offset by the tensor's number of dimensions.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    ndim = input_tensor.dim()
    axis = ndim + dim if dim < 0 else dim

    z_sorted = torch.sort(input_tensor, dim=axis, descending=True).values
    z_cumsum = z_sorted.cumsum(dim=axis)
    n = input_tensor.size(axis)

    # Ranks 1..n, reshaped so they broadcast along `axis` only.
    ranks = torch.arange(1, n + 1, device=input_tensor.device, dtype=input_tensor.dtype)
    broadcast_shape = [1] * ndim
    broadcast_shape[axis] = n
    ranks = ranks.view(broadcast_shape)

    # Sorted position k is in the support when 1 + k * z_(k) > cumsum_k.
    in_support = (1 + ranks * z_sorted) > z_cumsum
    support_size = in_support.sum(dim=axis, keepdim=True).clamp(min=1)
    support_total = (z_sorted * in_support).sum(dim=axis, keepdim=True)

    tau = (support_total - 1) / support_size
    return torch.clamp(input_tensor - tau, min=0)
|
31
|
+
|
32
|
+
|
33
|
+
@pytest.mark.parametrize(
    "batch_size, seq_len, features",
    [
        (2, 128, 512),
        (5, 123, 123),
    ],
)
@pytest.mark.parametrize("dim", [-1, 1])
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [(torch.float32, 1e-5, 1e-5)],
)
def test_liger_sparsemax_correctness(batch_size, seq_len, features, dim, dtype, atol, rtol):
    """Compare the ``LigerSparsemax`` module against ``torch_sparsemax``.

    Checks the forward outputs, that both outputs sum to one along ``dim``
    (with 10x looser tolerances), and the input gradients obtained from the
    same upstream gradient.
    """
    set_seed(0)
    shape = (batch_size, seq_len, features)
    rank = len(shape)
    if not (-rank <= dim < rank):
        pytest.skip("invalid dim")
    # `dim` is a valid index here, so negative values index from the end.
    if shape[dim] <= 1:
        pytest.skip("trivial dim")

    base = torch.randn(*shape, dtype=dtype, device=device)
    # Independent leaf copies so each implementation accumulates its own grad.
    liger_in = base.clone().requires_grad_(True)
    ref_in = base.clone().requires_grad_(True)

    module = LigerSparsemax(dim=dim).to(device)
    liger_out = module(liger_in)
    ref_out = torch_sparsemax(ref_in, dim=dim)
    assert_verbose_allclose(liger_out, ref_out, atol=atol, rtol=rtol)

    # Both outputs must sum to one along `dim`.
    loose_atol, loose_rtol = atol * 10, rtol * 10
    liger_total = liger_out.sum(dim=dim)
    ref_total = ref_out.sum(dim=dim)
    for total in (liger_total, ref_total):
        assert_verbose_allclose(total, torch.ones_like(total), atol=loose_atol, rtol=loose_rtol)

    upstream = torch.randn_like(base)
    liger_out.backward(upstream)
    ref_out.backward(upstream)
    assert_verbose_allclose(liger_in.grad, ref_in.grad, atol=atol, rtol=rtol)
|
71
|
+
|
72
|
+
|
73
|
+
@pytest.mark.parametrize(
    "batch_size, seq_len, features",
    [
        (2, 128, 512),
        (5, 123, 123),
    ],
)
@pytest.mark.parametrize("dim", [-1, 1])
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
    ],
)
def test_liger_sparsemax_functional_correctness(batch_size, seq_len, features, dim, dtype, atol, rtol):
    """Compare the functional ``liger_sparsemax`` against ``torch_sparsemax``.

    Checks the forward outputs, that both outputs sum to one along ``dim``
    (with 10x looser tolerances), and the input gradients obtained from the
    same upstream gradient.
    """
    set_seed(0)
    shape = (batch_size, seq_len, features)
    rank = len(shape)
    if not (-rank <= dim < rank):
        pytest.skip("invalid dim")
    # `dim` is a valid index here, so negative values index from the end.
    if shape[dim] <= 1:
        pytest.skip("trivial dim")

    base = torch.randn(*shape, dtype=dtype, device=device)
    # Independent leaf copies so each implementation accumulates its own grad.
    liger_in = base.clone().requires_grad_(True)
    ref_in = base.clone().requires_grad_(True)

    liger_out = liger_sparsemax(liger_in, dim=dim)
    ref_out = torch_sparsemax(ref_in, dim=dim)
    assert_verbose_allclose(liger_out, ref_out, atol=atol, rtol=rtol)

    # Both outputs must sum to one along `dim`.
    loose_atol, loose_rtol = atol * 10, rtol * 10
    liger_total = liger_out.sum(dim=dim)
    ref_total = ref_out.sum(dim=dim)
    for total in (liger_total, ref_total):
        assert_verbose_allclose(total, torch.ones_like(total), atol=loose_atol, rtol=loose_rtol)

    upstream = torch.randn_like(base)
    liger_out.backward(upstream)
    ref_out.backward(upstream)
    assert_verbose_allclose(liger_in.grad, ref_in.grad, atol=atol, rtol=rtol)
|
File without changes
|
File without changes
|
File without changes
|