liger-kernel-nightly 0.5.9.dev20250512213150__tar.gz → 0.5.9.dev20250515065336__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/data/all_benchmark_data.csv +72 -0
- liger_kernel_nightly-0.5.9.dev20250515065336/benchmark/scripts/benchmark_sparsemax.py +172 -0
- liger_kernel_nightly-0.5.9.dev20250515065336/examples/medusa/requirements.txt +3 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/scripts/llama3_8b_medusa.sh +2 -5
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/train.py +36 -38
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/pyproject.toml +1 -1
- liger_kernel_nightly-0.5.9.dev20250515065336/src/liger_kernel/ops/sparsemax.py +167 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/functional.py +8 -0
- liger_kernel_nightly-0.5.9.dev20250515065336/src/liger_kernel/transformers/sparsemax.py +16 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
- liger_kernel_nightly-0.5.9.dev20250515065336/test/transformers/test_sparsemax.py +111 -0
- liger_kernel_nightly-0.5.9.dev20250512213150/examples/medusa/requirements.txt +0 -3
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.idea/workspace.xml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/Makefile +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/setup.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/dyt.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/dyt.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/gema3_rms.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma3.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/glm4.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/llava.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/paligemma.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen3.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_dyt.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/utils.py +0 -0
@@ -805,3 +805,75 @@ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.265
|
|
805
805
|
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
806
806
|
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
807
807
|
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
808
|
+
sparsemax,liger,forward,speed,ms,V,feature size,1024,0.41471999883651733,0.4126720130443573,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
809
|
+
sparsemax,liger,forward,speed,ms,V,feature size,2048,0.7608320116996765,0.7598080039024353,0.7628800272941589,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
810
|
+
sparsemax,liger,forward,speed,ms,V,feature size,4096,1.4561280012130737,1.4540799856185913,1.4581760168075562,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
811
|
+
sparsemax,liger,forward,speed,ms,V,feature size,8192,5.288959980010986,5.2848639488220215,5.29986572265625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
812
|
+
sparsemax,liger,forward,speed,ms,V,feature size,16384,10.734624862670898,10.729472160339355,11.096882820129395,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
813
|
+
sparsemax,liger,forward,speed,ms,V,feature size,32768,21.729312896728516,21.7128963470459,22.20728302001953,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
|
814
|
+
sparsemax,torch,forward,speed,ms,V,feature size,1024,0.42291200160980225,0.42188799381256104,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
815
|
+
sparsemax,torch,forward,speed,ms,V,feature size,2048,0.7782400250434875,0.7772160172462463,0.779263973236084,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
816
|
+
sparsemax,torch,forward,speed,ms,V,feature size,4096,1.4940160512924194,1.491968035697937,1.4960639476776123,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
817
|
+
sparsemax,torch,forward,speed,ms,V,feature size,8192,5.359615802764893,5.356544017791748,5.366579055786133,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
818
|
+
sparsemax,torch,forward,speed,ms,V,feature size,16384,10.883584022521973,10.874879837036133,11.224268913269043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
819
|
+
sparsemax,torch,forward,speed,ms,V,feature size,32768,22.19878387451172,22.018457412719727,22.48888397216797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
|
820
|
+
sparsemax,liger,full,speed,ms,V,feature size,1024,0.4558719992637634,0.45558398962020874,0.45772799849510193,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
821
|
+
sparsemax,liger,full,speed,ms,V,feature size,2048,0.8488960266113281,0.8478720188140869,0.8509439826011658,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
822
|
+
sparsemax,liger,full,speed,ms,V,feature size,4096,1.6476160287857056,1.6465920209884644,1.6499264240264893,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
823
|
+
sparsemax,liger,full,speed,ms,V,feature size,8192,5.664768218994141,5.660672187805176,5.681356906890869,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
824
|
+
sparsemax,liger,full,speed,ms,V,feature size,16384,11.486207962036133,11.478015899658203,11.874713897705078,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
825
|
+
sparsemax,liger,full,speed,ms,V,feature size,32768,23.457279205322266,23.289682388305664,23.76642608642578,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
|
826
|
+
sparsemax,torch,full,speed,ms,V,feature size,1024,0.6021119952201843,0.6010879874229431,0.6041600108146667,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
827
|
+
sparsemax,torch,full,speed,ms,V,feature size,2048,1.1212799549102783,1.119264006614685,1.1223039627075195,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
828
|
+
sparsemax,torch,full,speed,ms,V,feature size,4096,2.1637120246887207,2.1616640090942383,2.165760040283203,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
829
|
+
sparsemax,torch,full,speed,ms,V,feature size,8192,6.693888187408447,6.68723201751709,6.705561637878418,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
830
|
+
sparsemax,torch,full,speed,ms,V,feature size,16384,13.523456573486328,13.518848419189453,13.878681182861328,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
831
|
+
sparsemax,torch,full,speed,ms,V,feature size,32768,27.604991912841797,27.295129776000977,27.77518081665039,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
|
832
|
+
sparsemax,liger,backward,speed,ms,V,feature size,1024,0.04403200000524521,0.043007999658584595,0.05222399905323982,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
833
|
+
sparsemax,liger,backward,speed,ms,V,feature size,2048,0.08806400001049042,0.08713600039482117,0.08806400001049042,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
834
|
+
sparsemax,liger,backward,speed,ms,V,feature size,4096,0.1884160041809082,0.1884160041809082,0.18943999707698822,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
835
|
+
sparsemax,liger,backward,speed,ms,V,feature size,8192,0.374783992767334,0.37376001477241516,0.37486720085144043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
836
|
+
sparsemax,liger,backward,speed,ms,V,feature size,16384,0.7516160011291504,0.7505919933319092,0.7516160011291504,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
837
|
+
sparsemax,liger,backward,speed,ms,V,feature size,32768,1.5738879442214966,1.572864055633545,1.575935959815979,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
|
838
|
+
sparsemax,torch,backward,speed,ms,V,feature size,1024,0.1812479943037033,0.1802240014076233,0.18227200210094452,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
839
|
+
sparsemax,torch,backward,speed,ms,V,feature size,2048,0.34406399726867676,0.34406399726867676,0.34508800506591797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
840
|
+
sparsemax,torch,backward,speed,ms,V,feature size,4096,0.6717439889907837,0.6707199811935425,0.6727679967880249,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
841
|
+
sparsemax,torch,backward,speed,ms,V,feature size,8192,1.3250559568405151,1.3241215944290161,1.3260799646377563,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
842
|
+
sparsemax,torch,backward,speed,ms,V,feature size,16384,2.629631996154785,2.628607988357544,2.6306560039520264,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
843
|
+
sparsemax,torch,backward,speed,ms,V,feature size,32768,5.236735820770264,5.235712051391602,5.239808082580566,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
844
|
+
sparsemax,liger,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
845
|
+
sparsemax,liger,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
846
|
+
sparsemax,liger,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
847
|
+
sparsemax,liger,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
848
|
+
sparsemax,liger,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
849
|
+
sparsemax,liger,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
|
850
|
+
sparsemax,torch,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
851
|
+
sparsemax,torch,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
852
|
+
sparsemax,torch,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
853
|
+
sparsemax,torch,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
854
|
+
sparsemax,torch,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
855
|
+
sparsemax,torch,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
|
856
|
+
sparsemax,liger,forward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
857
|
+
sparsemax,liger,forward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
858
|
+
sparsemax,liger,forward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
859
|
+
sparsemax,liger,forward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
860
|
+
sparsemax,liger,forward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
861
|
+
sparsemax,liger,forward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
862
|
+
sparsemax,torch,forward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
863
|
+
sparsemax,torch,forward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
864
|
+
sparsemax,torch,forward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
865
|
+
sparsemax,torch,forward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
866
|
+
sparsemax,torch,forward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
867
|
+
sparsemax,torch,forward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
|
868
|
+
sparsemax,liger,backward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
869
|
+
sparsemax,liger,backward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
870
|
+
sparsemax,liger,backward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
871
|
+
sparsemax,liger,backward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
872
|
+
sparsemax,liger,backward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
873
|
+
sparsemax,liger,backward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
|
874
|
+
sparsemax,torch,backward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
875
|
+
sparsemax,torch,backward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
876
|
+
sparsemax,torch,backward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
877
|
+
sparsemax,torch,backward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
878
|
+
sparsemax,torch,backward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
879
|
+
sparsemax,torch,backward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
|
@@ -0,0 +1,172 @@
|
|
1
|
+
import torch
|
2
|
+
import triton
|
3
|
+
|
4
|
+
from utils import QUANTILES
|
5
|
+
from utils import SingleBenchmarkRunInput
|
6
|
+
from utils import SingleBenchmarkRunOutput
|
7
|
+
from utils import _test_memory
|
8
|
+
from utils import parse_benchmark_script_args
|
9
|
+
from utils import run_benchmarks
|
10
|
+
|
11
|
+
from liger_kernel.transformers.sparsemax import LigerSparsemax
|
12
|
+
from liger_kernel.utils import infer_device
|
13
|
+
|
14
|
+
# Resolved once when the module is loaded; every benchmark module and tensor
# created below is placed on this device.
device = infer_device()
|
15
|
+
|
16
|
+
|
17
|
+
def torch_sparsemax(input_tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Reference sparsemax implemented with plain PyTorch ops.

    The values along ``dim`` are sorted in descending order; position ``j``
    (1-based) belongs to the support when ``1 + j * z_j > cumsum(z)_j``.
    The threshold ``tau`` is ``(sum of support values - 1) / support size``
    and the result is ``clamp(input - tau, min=0)``.

    Args:
        input_tensor: Tensor of scores.
        dim: Dimension to normalise over; negative values count from the end.

    Returns:
        A tensor with the same shape as ``input_tensor``.
    """
    input_dims = input_tensor.dim()
    if dim < 0:
        dim = input_dims + dim

    input_sorted, _ = torch.sort(input_tensor, dim=dim, descending=True)
    cumsum_input = torch.cumsum(input_sorted, dim=dim)

    # 1..N ranks, shaped so they broadcast along `dim` only.
    input_size = input_tensor.size(dim)
    range_tensor = torch.arange(1, input_size + 1, device=input_tensor.device, dtype=input_tensor.dtype)
    shape = [1] * input_dims
    shape[dim] = input_size
    range_tensor = range_tensor.view(shape)

    k_bound = 1 + range_tensor * input_sorted
    support = k_bound > cumsum_input
    k = support.sum(dim=dim, keepdim=True).clamp(min=1)

    # Select the support with a mask instead of multiplying by it: a product
    # evaluates -inf * 0 = NaN for masked (-inf) scores, which would poison
    # the whole row.
    support_sum = input_sorted.masked_fill(~support, 0.0).sum(dim=dim, keepdim=True)

    tau = (support_sum - 1) / k
    return torch.clamp(input_tensor - tau, min=0)
|
34
|
+
|
35
|
+
|
36
|
+
class TorchSparsemax(torch.nn.Module):
    """Module wrapper that applies ``torch_sparsemax`` along a fixed dimension."""

    def __init__(self, dim: int = -1):
        """Store the dimension that ``forward`` passes to ``torch_sparsemax``."""
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``torch_sparsemax(x)`` computed over ``self.dim``."""
        reduce_dim = self.dim
        return torch_sparsemax(x, dim=reduce_dim)
|
43
|
+
|
44
|
+
|
45
|
+
def bench_speed_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time one sparsemax configuration with ``triton.testing.do_bench``.

    Args:
        input: Benchmark description. ``input.x`` is the feature size ``V``,
            ``input.kernel_provider`` selects ``"liger"`` or ``"torch"``,
            ``input.kernel_operation_mode`` selects ``"forward"``,
            ``"backward"`` or ``"full"``, and ``input.extra_benchmark_config``
            supplies ``B``, ``T``, ``dim`` and ``dtype``.

    Returns:
        The three timings reported by ``do_bench`` for ``QUANTILES``.

    Raises:
        ValueError: If the provider or the operation mode is not recognised.
    """
    V = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    B = extra_benchmark_config["B"]
    T = extra_benchmark_config["T"]
    dim = extra_benchmark_config["dim"]
    dtype = extra_benchmark_config["dtype"]

    # Fail early with a clear message instead of letting an unknown provider
    # surface later as an AttributeError on a None forward result.
    if provider == "liger":
        sparsemax_module = LigerSparsemax(dim=dim).to(device)
    elif provider == "torch":
        sparsemax_module = TorchSparsemax(dim=dim).to(device)
    else:
        raise ValueError(f"Unsupported kernel provider for sparsemax benchmark: {provider!r}")

    x_shape = (B * T, V)
    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    # utility functions
    def y_fwd():
        return sparsemax_module(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    if mode == "forward":
        bench_fn = y_fwd
    elif mode == "backward":
        # Build the graph once; retain_graph lets every timed call reuse it.
        y = y_fwd()

        def bench_fn():
            y.backward(dy, retain_graph=True)

    elif mode == "full":
        bench_fn = full
    else:
        # Previously an unknown mode left the timings unbound and crashed at
        # the return statement with an UnboundLocalError.
        raise ValueError(f"Unsupported kernel operation mode for sparsemax benchmark: {mode!r}")

    # NOTE(review): the unpacking order assumes QUANTILES is (0.5, 0.2, 0.8) - confirm in utils.
    ms_50, ms_20, ms_80 = triton.testing.do_bench(
        bench_fn,
        grad_to_none=[x],
        rep=500,
        quantiles=QUANTILES,
    )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
|
105
|
+
|
106
|
+
|
107
|
+
def bench_memory_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory of one forward + backward sparsemax pass via ``_test_memory``.

    Args:
        input: Benchmark description. ``input.x`` is the feature size ``V``,
            ``input.kernel_provider`` selects ``"liger"`` or ``"torch"``, and
            ``input.extra_benchmark_config`` supplies ``B``, ``T``, ``dim``
            and ``dtype``.

    Returns:
        The three memory figures reported by ``_test_memory`` for ``QUANTILES``.

    Raises:
        ValueError: If the provider is not recognised.
    """
    V = input.x
    provider = input.kernel_provider

    extra_benchmark_config = input.extra_benchmark_config
    B = extra_benchmark_config["B"]
    T = extra_benchmark_config["T"]
    dim = extra_benchmark_config["dim"]
    dtype = extra_benchmark_config["dtype"]

    # Fail early with a clear message instead of letting an unknown provider
    # surface later as an AttributeError on a None forward result.
    if provider == "liger":
        sparsemax_module = LigerSparsemax(dim=dim).to(device)
    elif provider == "torch":
        sparsemax_module = TorchSparsemax(dim=dim).to(device)
    else:
        raise ValueError(f"Unsupported kernel provider for sparsemax benchmark: {provider!r}")

    x_shape = (B * T, V)
    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def full():
        y = sparsemax_module(x)
        y.backward(dy, retain_graph=True)

    # NOTE(review): the unpacking order assumes QUANTILES is (0.5, 0.2, 0.8) - confirm in utils.
    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
|
144
|
+
|
145
|
+
|
146
|
+
if __name__ == "__main__":
    cli_args = parse_benchmark_script_args()

    # Settings shared by the speed and the memory sweep: feature sizes
    # 2**10 .. 2**15, both providers, one (B, T, dim, dtype) configuration.
    shared_settings = {
        "kernel_name": "sparsemax",
        "x_name": "V",
        "x_label": "feature size",
        "x_values": [2**i for i in range(10, 16)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [{"B": 4, "T": 512, "dim": -1, "dtype": torch.float32}],
        "overwrite": cli_args.overwrite,
    }

    # Speed sweep over all three operation modes.
    run_benchmarks(
        bench_test_fn=bench_speed_sparsemax,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **shared_settings,
    )

    # Memory sweep; only the combined forward + backward pass is measured.
    run_benchmarks(
        bench_test_fn=bench_memory_sparsemax,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **shared_settings,
    )
|
@@ -22,9 +22,6 @@ export MEDUSA_LR_MULTIPLIER=4.0
|
|
22
22
|
accelerate launch --config_file fsdp/acc-fsdp.conf \
|
23
23
|
--num_machines $NUM_NODES \
|
24
24
|
--num_processes $WORLD_SIZE \
|
25
|
-
--main_process_ip $MASTER_ADDR \
|
26
|
-
--main_process_port $MASTER_PORT \
|
27
|
-
--machine_rank $RANK \
|
28
25
|
train.py \
|
29
26
|
--bf16 True \
|
30
27
|
--output_dir $OUTPUT_DIR \
|
@@ -32,7 +29,7 @@ accelerate launch --config_file fsdp/acc-fsdp.conf \
|
|
32
29
|
--per_device_train_batch_size $LOCAL_TRAIN_BATCH_SIZE \
|
33
30
|
--per_device_eval_batch_size 1 \
|
34
31
|
--gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
|
35
|
-
--
|
32
|
+
--eval_strategy "no" \
|
36
33
|
--save_strategy "no" \
|
37
34
|
--prediction_loss_only \
|
38
35
|
--learning_rate $LR \
|
@@ -53,4 +50,4 @@ accelerate launch --config_file fsdp/acc-fsdp.conf \
|
|
53
50
|
--medusa_lr_multiplier $MEDUSA_LR_MULTIPLIER \
|
54
51
|
--medusa_only_heads False \
|
55
52
|
--medusa_return True \
|
56
|
-
--use_liger True
|
53
|
+
--use_liger True
|
@@ -32,21 +32,18 @@ from callback import EfficiencyCallback
|
|
32
32
|
from medusa_util import add_medusa_heads
|
33
33
|
from safetensors.torch import save_file
|
34
34
|
from sklearn.model_selection import train_test_split
|
35
|
-
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
36
|
-
from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
|
37
|
-
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
38
35
|
from torch.utils.data import Dataset
|
39
36
|
from transformers import Trainer
|
40
37
|
from transformers.trainer_pt_utils import LabelSmoother
|
41
38
|
|
42
|
-
from liger_kernel.transformers import
|
39
|
+
from liger_kernel.transformers import AutoLigerKernelForCausalLM
|
43
40
|
|
44
41
|
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
45
42
|
|
46
43
|
|
47
44
|
@dataclass
|
48
45
|
class ModelArguments:
|
49
|
-
model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3-8B")
|
46
|
+
model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3-8B-Instruct")
|
50
47
|
|
51
48
|
|
52
49
|
@dataclass
|
@@ -310,29 +307,36 @@ def train():
|
|
310
307
|
print(tokenizer(["This is a test", "secondary"], padding=True))
|
311
308
|
print(tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}]))
|
312
309
|
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
# config=config,
|
317
|
-
cache_dir=training_args.cache_dir,
|
318
|
-
torch_dtype=torch.bfloat16,
|
319
|
-
)
|
310
|
+
def _model_loader():
|
311
|
+
# we use a customized model loader to inject medusa heads to FSDP-wrapped model variables properly.
|
312
|
+
# see https://github.com/linkedin/Liger-Kernel/issues/309#issuecomment-2455077623 for details.
|
320
313
|
|
321
|
-
|
322
|
-
|
314
|
+
# Load model
|
315
|
+
if training_args.use_liger:
|
316
|
+
model_builder = AutoLigerKernelForCausalLM.from_pretrained
|
317
|
+
else:
|
318
|
+
model_builder = transformers.AutoModelForCausalLM.from_pretrained
|
319
|
+
model = model_builder(
|
320
|
+
model_args.model_name_or_path,
|
321
|
+
cache_dir=training_args.cache_dir,
|
322
|
+
torch_dtype=torch.bfloat16,
|
323
|
+
)
|
323
324
|
|
324
|
-
|
325
|
-
|
326
|
-
|
325
|
+
# Freeze the base model
|
326
|
+
for param in model.base_model.parameters():
|
327
|
+
param.requires_grad = False
|
328
|
+
|
329
|
+
# Inject Medusa heads
|
330
|
+
add_medusa_heads(
|
331
|
+
model,
|
332
|
+
training_args.medusa_num_heads,
|
333
|
+
training_args.medusa_num_layers,
|
334
|
+
training_args.medusa_return,
|
335
|
+
training_args.medusa_only_heads,
|
336
|
+
training_args.use_liger,
|
337
|
+
)
|
338
|
+
return model
|
327
339
|
|
328
|
-
add_medusa_heads(
|
329
|
-
model,
|
330
|
-
training_args.medusa_num_heads,
|
331
|
-
training_args.medusa_num_layers,
|
332
|
-
training_args.medusa_return,
|
333
|
-
training_args.medusa_only_heads,
|
334
|
-
training_args.use_liger,
|
335
|
-
)
|
336
340
|
# Format output dir
|
337
341
|
training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"
|
338
342
|
|
@@ -341,7 +345,7 @@ def train():
|
|
341
345
|
|
342
346
|
# Start trainer
|
343
347
|
trainer = Trainer(
|
344
|
-
|
348
|
+
model_init=_model_loader,
|
345
349
|
tokenizer=tokenizer,
|
346
350
|
args=training_args,
|
347
351
|
callbacks=[EfficiencyCallback()],
|
@@ -355,17 +359,11 @@ def train():
|
|
355
359
|
|
356
360
|
if training_args.medusa_return and training_args.medusa_only_heads:
|
357
361
|
# Save only the updated head without saving the backbone model
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
with FSDP.state_dict_type(
|
364
|
-
model,
|
365
|
-
StateDictType.FULL_STATE_DICT,
|
366
|
-
FullStateDictConfig(offload_to_cpu=True),
|
367
|
-
):
|
368
|
-
state_dict = lm_head.state_dict()
|
362
|
+
state_dict = {
|
363
|
+
k.replace("medusa_head.", ""): v.to(torch.bfloat16)
|
364
|
+
for k, v in trainer.accelerator.get_state_dict(trainer.model).items()
|
365
|
+
if "medusa_head" in k
|
366
|
+
}
|
369
367
|
|
370
368
|
# Save Medusa heads
|
371
369
|
if local_rank == 0:
|
@@ -373,9 +371,9 @@ def train():
|
|
373
371
|
state_dict,
|
374
372
|
os.path.join(training_args.output_dir, "medusa_lm_head.safetensors"),
|
375
373
|
)
|
374
|
+
trainer.accelerator.wait_for_everyone()
|
376
375
|
else:
|
377
376
|
# Save the whole model weight
|
378
|
-
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
379
377
|
trainer.save_model(training_args.output_dir)
|
380
378
|
|
381
379
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.5.9.
|
7
|
+
version = "0.5.9.dev20250515065336"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -0,0 +1,167 @@
|
|
1
|
+
import torch
|
2
|
+
import triton
|
3
|
+
import triton.language as tl
|
4
|
+
|
5
|
+
from liger_kernel.ops.utils import calculate_settings
|
6
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
7
|
+
|
8
|
+
|
9
|
+
@triton.jit
def _sparsemax_forward_kernel(
    x_ptr,
    x_stride_row,
    sorted_x_ptr,
    sorted_x_stride_row,
    o_ptr,
    o_stride_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """Compute sparsemax for one row per program.

    The threshold ``tau`` is derived from the row read through ``sorted_x_ptr``
    and then applied to the matching row of ``x_ptr``; the result
    ``max(x - tau, 0)`` is written to ``o_ptr``. All arithmetic is done in
    float32 and cast to the output element type on store.

    NOTE(review): presumably the ``sorted_x_ptr`` rows are sorted in descending
    order by the caller - confirm against the launcher.
    NOTE(review): a single block of ``BLOCK_SIZE`` columns is loaded per row, so
    this appears to require ``BLOCK_SIZE >= n_cols`` - confirm against the launcher.
    """
    row_idx = tl.program_id(0)
    x_row = x_ptr + row_idx * x_stride_row
    sorted_row = sorted_x_ptr + row_idx * sorted_x_stride_row
    out_row = o_ptr + row_idx * o_stride_row

    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    z_sorted = tl.load(
        sorted_row + cols,
        mask=in_bounds,
        other=-float("inf"),
        cache_modifier=".ca",
    ).to(tl.float32)

    # Out-of-range lanes hold -inf; zero them so they do not disturb the prefix sum.
    prefix_sum = tl.cumsum(tl.where(in_bounds, z_sorted, 0.0), 0)

    # 1-based rank of each lane; padded lanes use 1 to keep the division well defined.
    rank = tl.where(in_bounds, (cols + 1).to(tl.float32), 1.0)
    candidate_tau = (prefix_sum - 1.0) / rank

    # A lane is in the support when its value exceeds its own candidate threshold.
    in_support = (z_sorted > candidate_tau) & in_bounds

    # Support size, clamped to at least 1 before it is used as a divisor.
    support_size = tl.maximum(tl.sum(in_support.to(tl.int32), 0), 1).to(tl.float32)
    support_total = tl.sum(tl.where(in_support, z_sorted, 0.0), 0)

    tau = (support_total - 1.0) / support_size

    x_vals = tl.load(
        x_row + cols,
        mask=in_bounds,
        other=0.0,
        cache_modifier=".ca",
    ).to(tl.float32)

    result = tl.maximum(x_vals - tau, 0.0)

    tl.store(
        out_row + cols,
        result.to(out_row.dtype.element_ty),
        mask=in_bounds,
        cache_modifier=".cs",
    )
|
69
|
+
|
70
|
+
|
71
|
+
@triton.jit
def _sparsemax_backward_kernel(
    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
):
    """Compute the sparsemax input gradient for one row per program.

    Support positions are those where the forward output ``o`` is positive.
    For them the gradient is ``go - mean(go over the support)``; every other
    position receives 0. The row is processed in chunks of ``BLOCK_SIZE``
    columns, first to accumulate the support statistics and then to write
    the gradient. ``o``, ``go`` and ``gi`` are addressed with the same row
    ``stride``.
    """
    row_idx = tl.program_id(0)
    out_row = o_ptr + row_idx * stride
    grad_out_row = go_ptr + row_idx * stride
    grad_in_row = gi_ptr + row_idx * stride

    lane = tl.arange(0, BLOCK_SIZE)

    support_count = tl.zeros((), tl.float32)
    grad_out_total = tl.zeros((), tl.float32)

    # Pass 1: count the support and sum the upstream gradient over it.
    for blk in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        cols = blk * BLOCK_SIZE + lane
        in_bounds = cols < n_cols
        out_val = tl.load(out_row + cols, mask=in_bounds, other=0.0, cache_modifier=".ca").to(tl.float32)
        grad_out_val = tl.load(grad_out_row + cols, mask=in_bounds, other=0.0).to(tl.float32)
        in_support = out_val > 0.0
        grad_out_total += tl.sum(tl.where(in_support, grad_out_val, 0.0))
        support_count += tl.sum(in_support.to(tl.float32))

    # Mean upstream gradient over the support. The divisor is floored at 1e-6
    # so an empty support does not divide by zero, and the mean is rounded to
    # the gradient's element type before being used in float32 again.
    support_mean = tl.cast(
        grad_out_total / tl.maximum(support_count, 1e-6), grad_in_row.dtype.element_ty
    ).to(tl.float32)

    # Pass 2: write go - mean on the support and 0 elsewhere.
    for blk in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        cols = blk * BLOCK_SIZE + lane
        in_bounds = cols < n_cols
        out_val = tl.load(out_row + cols, mask=in_bounds, other=0.0, cache_modifier=".ca").to(tl.float32)
        grad_out_val = tl.load(grad_out_row + cols, mask=in_bounds, other=0.0).to(tl.float32)
        in_support = out_val > 0.0
        grad_in_val = tl.where(in_support, grad_out_val - support_mean, 0.0)
        tl.store(
            grad_in_row + cols,
            grad_in_val.to(grad_in_row.dtype.element_ty),
            mask=in_bounds,
            cache_modifier=".wb",
        )
|
106
|
+
|
107
|
+
|
108
|
+
class LigerSparsemaxFunction(torch.autograd.Function):
    """
    Autograd function computing sparsemax along one dimension with Triton kernels.

    The target dimension is moved to the last axis and the tensor is flattened to
    ``(n_rows, n_cols)``; one kernel program is launched per row. The forward
    output is saved and reused by the backward pass to recover the support.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        """
        Compute sparsemax of ``x`` along ``dim``.

        Args:
            ctx: autograd context; receives ``dim`` and the flattened output.
            x: input tensor.
            dim: dimension to normalize over; negative values count from the end.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if dim < 0:
            dim += x.dim()
        ctx.dim = dim

        # A tensor with zero elements has nothing to normalize. Without this guard an
        # empty reduction dimension gives n_cols == 0, and the row-count division
        # below raises ZeroDivisionError.
        if x.numel() == 0:
            empty_out = torch.empty_like(x)
            ctx.save_for_backward(empty_out)
            return empty_out

        # Move the reduction dimension last and flatten the rest into rows.
        x_sw = x.transpose(dim, -1).contiguous()
        n_cols = x_sw.size(-1)
        n_rows = x_sw.numel() // n_cols
        x_flat = x_sw.view(n_rows, n_cols)

        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        out_flat = torch.empty_like(x_flat)
        grid = (n_rows,)

        # The kernel derives the threshold from a float32 copy sorted in descending order.
        x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values

        _sparsemax_forward_kernel[grid](
            x_flat,
            x_flat.stride(0),
            x_sorted_flat,
            x_sorted_flat.stride(0),
            out_flat,
            out_flat.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        ctx.save_for_backward(out_flat)
        # Undo the flatten and the transpose so the result matches the input layout.
        return out_flat.view_as(x_sw).transpose(dim, -1)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        """
        Compute the gradient of sparsemax with respect to its input.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_out: gradient with respect to the forward output.

        Returns:
            A tuple ``(grad_input, None)``; ``None`` is the gradient for ``dim``.
        """
        (out_flat,) = ctx.saved_tensors
        dim = ctx.dim

        # Mirror the zero-element guard in forward: there is nothing to differentiate.
        if grad_out.numel() == 0:
            return torch.empty_like(grad_out), None

        # Apply the same transpose + flatten used in forward so rows line up with out_flat.
        go_sw = grad_out.transpose(dim, -1).contiguous()
        n_cols = go_sw.size(-1)
        n_rows = go_sw.numel() // n_cols
        go_flat = go_sw.view(n_rows, n_cols)

        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        gi_flat = torch.empty_like(go_flat)
        grid = (n_rows,)

        # A single row stride is passed for all three buffers.
        _sparsemax_backward_kernel[grid](
            out_flat,
            go_flat,
            gi_flat,
            out_flat.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        return gi_flat.view_as(go_sw).transpose(dim, -1), None
|
@@ -12,6 +12,7 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
|
|
12
12
|
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
13
13
|
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
|
14
14
|
from liger_kernel.ops.rope import LigerRopeFunction
|
15
|
+
from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
|
15
16
|
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
|
16
17
|
from liger_kernel.ops.tvd import LigerTVDLossFunction
|
17
18
|
|
@@ -159,6 +160,13 @@ def liger_kl_div(
|
|
159
160
|
)
|
160
161
|
|
161
162
|
|
163
|
+
def liger_sparsemax(
    input,
    dim: int = -1,
):
    """
    Functional entry point for sparsemax.

    Args:
        input: tensor passed to ``LigerSparsemaxFunction``.
        dim: dimension along which sparsemax is applied. Defaults to the last dimension.

    Returns:
        Whatever ``LigerSparsemaxFunction.apply(input, dim)`` returns.
    """
    output = LigerSparsemaxFunction.apply(input, dim)
    return output
|
168
|
+
|
169
|
+
|
162
170
|
def liger_tvd(
|
163
171
|
input,
|
164
172
|
target,
|
@@ -0,0 +1,16 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
|
4
|
+
from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
|
5
|
+
|
6
|
+
|
7
|
+
class LigerSparsemax(nn.Module):
    """Module wrapper that applies ``LigerSparsemaxFunction`` along a fixed dimension."""

    def __init__(self, dim: int = -1):
        """Store the dimension to normalize over (defaults to the last one)."""
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sparsemax of ``x`` along ``self.dim``."""
        target_dim = self.dim
        return LigerSparsemaxFunction.apply(x, target_dim)

    def extra_repr(self) -> str:
        """Return the configured dimension as ``dim=<value>``."""
        return "dim={}".format(self.dim)
|