liger-kernel 0.6.4__tar.gz → 0.6.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.github/workflows/amd-ci.yml +4 -4
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.github/workflows/benchmark.yml +14 -10
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.github/workflows/docs.yml +3 -3
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.github/workflows/intel-ci.yml +3 -3
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.github/workflows/nvi-ci.yml +6 -6
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.github/workflows/publish-nightly.yml +2 -2
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.github/workflows/publish-release.yml +2 -2
- liger_kernel-0.6.5/.pre-commit-config.yaml +10 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/PKG-INFO +11 -4
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/README.md +8 -2
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_distill_jsd_loss.py +10 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_jsd.py +8 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_qwen2vl_mrope.py +32 -18
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_tvd.py +15 -5
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/contributing.md +45 -48
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/pyproject.toml +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/setup.py +22 -2
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/jsd_loss.py +21 -6
- liger_kernel-0.6.5/src/liger_kernel/ops/__init__.py +141 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/README.md +151 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
- liger_kernel-0.6.5/src/liger_kernel/ops/backends/registry.py +61 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/cross_entropy.py +14 -4
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/dyt.py +5 -2
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/fused_add_rms_norm.py +21 -23
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/geglu.py +5 -3
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/group_norm.py +12 -8
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/kl_div.py +8 -11
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/layer_norm.py +17 -16
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/poly_norm.py +19 -21
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/rms_norm.py +149 -71
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/utils.py +25 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/__init__.py +6 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/auto_model.py +21 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/cross_entropy.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/dyt.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/experimental/embedding.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/functional.py +20 -20
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/fused_linear_jsd.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/geglu.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/group_norm.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/grpo_loss.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/jsd.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/kl_div.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/layer_norm.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel-0.6.5/src/liger_kernel/transformers/model/exaone4.py +136 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/gemma2.py +3 -3
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/gemma3.py +11 -5
- liger_kernel-0.6.5/src/liger_kernel/transformers/model/gpt_oss.py +211 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/loss_utils.py +6 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/paligemma.py +1 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/monkey_patch.py +196 -39
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/multi_token_attention.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/poly_norm.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel-0.6.5/src/liger_kernel/transformers/rope.py +64 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/softmax.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/sparsemax.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/swiglu.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/tiled_mlp.py +5 -13
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/tvd.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/utils.py +54 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel.egg-info/PKG-INFO +11 -4
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel.egg-info/SOURCES.txt +18 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel.egg-info/requires.txt +2 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/chunked_loss/test_cosine_loss.py +1 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/chunked_loss/test_jsd_loss.py +27 -9
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/conftest.py +4 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/convergence/bf16/test_mini_models.py +122 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/convergence/bf16/test_mini_models_with_logits.py +56 -1
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/convergence/fp32/test_mini_models.py +117 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/convergence/fp32/test_mini_models_multimodal.py +18 -4
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/convergence/fp32/test_mini_models_with_logits.py +63 -4
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_auto_model.py +35 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_cross_entropy.py +1 -1
- liger_kernel-0.6.5/test/transformers/test_geglu.py +264 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_group_norm.py +5 -11
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_jsd.py +29 -1
- liger_kernel-0.6.5/test/transformers/test_llama4_rope.py +149 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_monkey_patch.py +75 -73
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_rms_norm.py +142 -16
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_tiled_mlp.py +37 -56
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/utils.py +31 -1
- liger_kernel-0.6.4/src/liger_kernel/transformers/rope.py +0 -63
- liger_kernel-0.6.4/test/convergence/fp32/__init__.py +0 -0
- liger_kernel-0.6.4/test/transformers/test_geglu.py +0 -145
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.github/pull_request_template.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/.gitignore +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/LICENSE +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/Makefile +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/NOTICE +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/README.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_distill_cosine_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_fused_add_rms_norm.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_fused_neighborhood_attention.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_grpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_llama4_rope.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_multi_token_attention.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_poly_norm.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_softmax.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_sparse_multi_token_attention.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_sparsemax.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/benchmark_tiled_mlp.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/benchmark/scripts/utils.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/dev/fmt-requirements.txt +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/dev/modal/benchmarks.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/dev/modal/tests.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/Examples.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/Getting-Started.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/High-Level-APIs.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/acknowledgement.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/images/banner.GIF +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/images/compose.gif +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/images/e2e-memory.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/images/e2e-tps.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/images/logo-banner.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/images/patch.gif +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/images/post-training.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/index.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/docs/license.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/README.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/callback.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/training.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/lightning/README.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/lightning/requirements.txt +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/lightning/training.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/README.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/callback.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/requirements.txt +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/examples/medusa/train.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/mkdocs.yml +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/setup.cfg +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/fused_neighborhood_attention.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/grpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/llama4_rope.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/multi_token_attention.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/softmax.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/sparsemax.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/tiled_mlp.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/experimental/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/fsdp.py +0 -0
- {liger_kernel-0.6.4/src/liger_kernel/ops → liger_kernel-0.6.5/src/liger_kernel/transformers/model}/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/falcon_h1.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/glm4.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/glm4v.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/glm4v_moe.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/hunyuan_v1.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/internvl.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/llama4.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/llava.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/olmo3.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/output_classes.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/qwen3.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/qwen3_next.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/qwen3_vl.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/qwen3_vl_moe.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/smollm3.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/model/smolvlm.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/src/liger_kernel.egg-info/top_level.txt +0 -0
- {liger_kernel-0.6.4/src/liger_kernel/transformers/model → liger_kernel-0.6.5/test}/__init__.py +0 -0
- {liger_kernel-0.6.4/test → liger_kernel-0.6.5/test/chunked_loss}/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel-0.6.4/test/chunked_loss → liger_kernel-0.6.5/test/convergence}/__init__.py +0 -0
- {liger_kernel-0.6.4/test/convergence → liger_kernel-0.6.5/test/convergence/bf16}/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel-0.6.4/test/convergence/bf16 → liger_kernel-0.6.5/test/convergence/fp32}/__init__.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/HuggingFaceTB/SmolVLM2-256M-Video-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/OpenGVLab/InternVL3-1B-hf/tokenizer_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/Qwen/Qwen3-VL-4B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/fake_configs/meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_dyt.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_embedding.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_fused_add_rms_norm.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_fused_neighborhood_attention.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_grpo_loss.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_multi_token_attention.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_poly_norm.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_rope.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_softmax.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_sparsemax.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_transformers.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/transformers/test_tvd.py +0 -0
- {liger_kernel-0.6.4 → liger_kernel-0.6.5}/test/triton/test_triton_monkey_patch.py +0 -0
|
@@ -26,10 +26,10 @@ jobs:
|
|
|
26
26
|
|
|
27
27
|
steps:
|
|
28
28
|
- name: Checkout code
|
|
29
|
-
uses: actions/checkout@
|
|
29
|
+
uses: actions/checkout@v6
|
|
30
30
|
|
|
31
31
|
- name: Set up Python
|
|
32
|
-
uses: actions/setup-python@
|
|
32
|
+
uses: actions/setup-python@v6
|
|
33
33
|
with:
|
|
34
34
|
python-version: '3.10'
|
|
35
35
|
|
|
@@ -50,10 +50,10 @@ jobs:
|
|
|
50
50
|
|
|
51
51
|
steps:
|
|
52
52
|
- name: Checkout code
|
|
53
|
-
uses: actions/checkout@
|
|
53
|
+
uses: actions/checkout@v6
|
|
54
54
|
|
|
55
55
|
- name: Set up Python
|
|
56
|
-
uses: actions/setup-python@
|
|
56
|
+
uses: actions/setup-python@v6
|
|
57
57
|
with:
|
|
58
58
|
python-version: '3.10'
|
|
59
59
|
|
|
@@ -35,22 +35,26 @@ jobs:
|
|
|
35
35
|
OUTPUT_DIR: benchmarks
|
|
36
36
|
OUTPUT_FILENAME: benchmark.csv
|
|
37
37
|
GENERATED_CSV: benchmark/data/all_benchmark_data.csv
|
|
38
|
+
# Sanitize user-controlled inputs by declaring them as environment variables
|
|
39
|
+
# This prevents command injection attacks by filtering dangerous characters
|
|
40
|
+
INPUT_COMMIT_HASH: ${{ github.event.inputs.commit_hash }}
|
|
41
|
+
INPUT_OVERWRITE: ${{ github.event.inputs.overwrite }}
|
|
38
42
|
|
|
39
43
|
|
|
40
44
|
steps:
|
|
41
45
|
# Step: Decide the commit hash to use
|
|
42
46
|
# Step: Checkout full history so we can check out any commit
|
|
43
47
|
- name: Checkout full repo history
|
|
44
|
-
uses: actions/checkout@
|
|
48
|
+
uses: actions/checkout@v6
|
|
45
49
|
with:
|
|
46
50
|
fetch-depth: 0 # Important: so we can checkout arbitrary commit
|
|
47
51
|
|
|
48
52
|
- name: Determine commit hash to checkout
|
|
49
53
|
id: choose_commit
|
|
50
54
|
run: |
|
|
51
|
-
if [ "${{ github.event_name}}" == "workflow_dispatch" ] && [ "$
|
|
52
|
-
echo "Using manual input commit: $
|
|
53
|
-
echo "hash=$
|
|
55
|
+
if [ "${{ github.event_name}}" == "workflow_dispatch" ] && [ "$INPUT_COMMIT_HASH" != "main" ]; then
|
|
56
|
+
echo "Using manual input commit: $INPUT_COMMIT_HASH"
|
|
57
|
+
echo "hash=$INPUT_COMMIT_HASH" >> $GITHUB_OUTPUT
|
|
54
58
|
else
|
|
55
59
|
echo "Using latest commit from main"
|
|
56
60
|
echo "hash=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
|
|
@@ -60,10 +64,10 @@ jobs:
|
|
|
60
64
|
- name: Replace benchmark folder from main (manual only, commit ≠ main)
|
|
61
65
|
if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.commit_hash != 'main' }}
|
|
62
66
|
run: |
|
|
63
|
-
echo "Detected manual trigger with commit_hash = $
|
|
67
|
+
echo "Detected manual trigger with commit_hash = $INPUT_COMMIT_HASH"
|
|
64
68
|
|
|
65
69
|
# Save current branch (detached HEAD at old commit)
|
|
66
|
-
ORIG_COMMIT
|
|
70
|
+
ORIG_COMMIT="$INPUT_COMMIT_HASH"
|
|
67
71
|
|
|
68
72
|
# Fetch and checkout main
|
|
69
73
|
git fetch origin main
|
|
@@ -72,7 +76,7 @@ jobs:
|
|
|
72
76
|
# Save benchmark folder from main
|
|
73
77
|
cp -r benchmark /tmp/benchmark_main
|
|
74
78
|
# Checkout back to target commit
|
|
75
|
-
git checkout $ORIG_COMMIT
|
|
79
|
+
git checkout "$ORIG_COMMIT"
|
|
76
80
|
# Replace old benchmark with one from main
|
|
77
81
|
rm -rf benchmark
|
|
78
82
|
cp -r /tmp/benchmark_main benchmark
|
|
@@ -85,7 +89,7 @@ jobs:
|
|
|
85
89
|
|
|
86
90
|
if curl --output /dev/null --silent --head --fail "$BENCHMARK_URL"; then
|
|
87
91
|
echo "Benchmark already exists for commit $COMMIT_HASH"
|
|
88
|
-
if [ "$
|
|
92
|
+
if [ "$INPUT_OVERWRITE" != "true" ]; then
|
|
89
93
|
echo "Overwrite is false - exiting"
|
|
90
94
|
exit 1
|
|
91
95
|
else
|
|
@@ -96,7 +100,7 @@ jobs:
|
|
|
96
100
|
fi
|
|
97
101
|
|
|
98
102
|
- name: Set up Python
|
|
99
|
-
uses: actions/setup-python@
|
|
103
|
+
uses: actions/setup-python@v6
|
|
100
104
|
with:
|
|
101
105
|
python-version: '3.10'
|
|
102
106
|
|
|
@@ -117,7 +121,7 @@ jobs:
|
|
|
117
121
|
|
|
118
122
|
# Step 5: Checkout gh-pages branch in a subfolderAdd commentMore actions
|
|
119
123
|
- name: Checkout gh-pages
|
|
120
|
-
uses: actions/checkout@
|
|
124
|
+
uses: actions/checkout@v6
|
|
121
125
|
with:
|
|
122
126
|
ref: gh-pages
|
|
123
127
|
path: gh-pages
|
|
@@ -13,16 +13,16 @@ jobs:
|
|
|
13
13
|
deploy:
|
|
14
14
|
runs-on: ubuntu-latest
|
|
15
15
|
steps:
|
|
16
|
-
- uses: actions/checkout@
|
|
16
|
+
- uses: actions/checkout@v6
|
|
17
17
|
- name: Configure Git Credentials
|
|
18
18
|
run: |
|
|
19
19
|
git config user.name github-actions[bot]
|
|
20
20
|
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
|
|
21
|
-
- uses: actions/setup-python@
|
|
21
|
+
- uses: actions/setup-python@v6
|
|
22
22
|
with:
|
|
23
23
|
python-version: 3.x
|
|
24
24
|
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
|
|
25
|
-
- uses: actions/cache@
|
|
25
|
+
- uses: actions/cache@v5
|
|
26
26
|
with:
|
|
27
27
|
key: mkdocs-material-${{ env.cache_id }}
|
|
28
28
|
path: .cache
|
|
@@ -26,10 +26,10 @@ jobs:
|
|
|
26
26
|
|
|
27
27
|
steps:
|
|
28
28
|
- name: Checkout code
|
|
29
|
-
uses: actions/checkout@
|
|
29
|
+
uses: actions/checkout@v6
|
|
30
30
|
|
|
31
31
|
- name: Set up Python
|
|
32
|
-
uses: actions/setup-python@
|
|
32
|
+
uses: actions/setup-python@v6
|
|
33
33
|
with:
|
|
34
34
|
python-version: '3.10'
|
|
35
35
|
|
|
@@ -58,7 +58,7 @@ jobs:
|
|
|
58
58
|
apt-get clean && rm -rf /var/lib/apt/lists/*
|
|
59
59
|
|
|
60
60
|
- name: Checkout code
|
|
61
|
-
uses: actions/checkout@
|
|
61
|
+
uses: actions/checkout@v6
|
|
62
62
|
|
|
63
63
|
- name: Setup Dependencies
|
|
64
64
|
shell: bash
|
|
@@ -25,10 +25,10 @@ jobs:
|
|
|
25
25
|
|
|
26
26
|
steps:
|
|
27
27
|
- name: Checkout code
|
|
28
|
-
uses: actions/checkout@
|
|
28
|
+
uses: actions/checkout@v6
|
|
29
29
|
|
|
30
30
|
- name: Set up Python
|
|
31
|
-
uses: actions/setup-python@
|
|
31
|
+
uses: actions/setup-python@v6
|
|
32
32
|
with:
|
|
33
33
|
python-version: '3.10'
|
|
34
34
|
|
|
@@ -49,10 +49,10 @@ jobs:
|
|
|
49
49
|
|
|
50
50
|
steps:
|
|
51
51
|
- name: Checkout code
|
|
52
|
-
uses: actions/checkout@
|
|
52
|
+
uses: actions/checkout@v6
|
|
53
53
|
|
|
54
54
|
- name: Set up Python
|
|
55
|
-
uses: actions/setup-python@
|
|
55
|
+
uses: actions/setup-python@v6
|
|
56
56
|
with:
|
|
57
57
|
python-version: '3.10'
|
|
58
58
|
|
|
@@ -75,10 +75,10 @@ jobs:
|
|
|
75
75
|
|
|
76
76
|
steps:
|
|
77
77
|
- name: Checkout code
|
|
78
|
-
uses: actions/checkout@
|
|
78
|
+
uses: actions/checkout@v6
|
|
79
79
|
|
|
80
80
|
- name: Set up Python
|
|
81
|
-
uses: actions/setup-python@
|
|
81
|
+
uses: actions/setup-python@v6
|
|
82
82
|
with:
|
|
83
83
|
python-version: '3.10'
|
|
84
84
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: liger_kernel
|
|
3
|
-
Version: 0.6.
|
|
3
|
+
Version: 0.6.5
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -33,7 +33,7 @@ License-File: NOTICE
|
|
|
33
33
|
Requires-Dist: torch>=2.1.2
|
|
34
34
|
Requires-Dist: triton>=2.3.1
|
|
35
35
|
Provides-Extra: dev
|
|
36
|
-
Requires-Dist: transformers
|
|
36
|
+
Requires-Dist: transformers<5.0.0,>=4.49.0; extra == "dev"
|
|
37
37
|
Requires-Dist: matplotlib>=3.7.2; extra == "dev"
|
|
38
38
|
Requires-Dist: ruff>=0.12.0; extra == "dev"
|
|
39
39
|
Requires-Dist: pytest>=7.1.2; extra == "dev"
|
|
@@ -45,6 +45,7 @@ Requires-Dist: datasets>=2.19.2; extra == "dev"
|
|
|
45
45
|
Requires-Dist: seaborn; extra == "dev"
|
|
46
46
|
Requires-Dist: mkdocs-material; extra == "dev"
|
|
47
47
|
Requires-Dist: torchvision>=0.20; extra == "dev"
|
|
48
|
+
Requires-Dist: prek>=0.2.28; extra == "dev"
|
|
48
49
|
Dynamic: license-file
|
|
49
50
|
Dynamic: provides-extra
|
|
50
51
|
Dynamic: requires-dist
|
|
@@ -82,8 +83,8 @@ Dynamic: requires-dist
|
|
|
82
83
|
</a>
|
|
83
84
|
</td>
|
|
84
85
|
<td style="padding: 10px;">
|
|
85
|
-
<a href="https://discord.gg/
|
|
86
|
-
<img src="https://dcbadge.limes.pink/api/server/
|
|
86
|
+
<a href="https://discord.gg/X4MaxPgA">
|
|
87
|
+
<img src="https://dcbadge.limes.pink/api/server/https://discord.gg/X4MaxPgA?style=flat" alt="Join Our Discord">
|
|
87
88
|
</a>
|
|
88
89
|
</td>
|
|
89
90
|
</tr>
|
|
@@ -98,6 +99,7 @@ Dynamic: requires-dist
|
|
|
98
99
|
<details>
|
|
99
100
|
<summary>Latest News 🔥</summary>
|
|
100
101
|
|
|
102
|
+
- [2025/12/19] We announced a liger kernel discord channel at https://discord.gg/X4MaxPgA; We will be hosting Liger Kernel x Triton China Meetup in mid of January 2026
|
|
101
103
|
- [2025/03/06] We release a joint blog post on TorchTune × Liger - [Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel](https://pytorch.org/blog/peak-performance-minimized-memory/)
|
|
102
104
|
- [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
|
|
103
105
|
- [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
|
|
@@ -116,6 +118,8 @@ We've also added optimized Post-Training kernels that deliver **up to 80% memory
|
|
|
116
118
|
|
|
117
119
|
You can view the documentation site for additional installation, usage examples, and API references:https://linkedin.github.io/Liger-Kernel/
|
|
118
120
|
|
|
121
|
+
You can view the Liger Kernel Technical Report: https://openreview.net/forum?id=36SjAIT42G
|
|
122
|
+
|
|
119
123
|
## Supercharge Your Model with Liger Kernel
|
|
120
124
|
|
|
121
125
|

|
|
@@ -315,6 +319,7 @@ loss.backward()
|
|
|
315
319
|
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
316
320
|
| Olmo3 | `liger_kernel.transformers.apply_liger_kernel_to_olmo3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
317
321
|
| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
322
|
+
| GPT-OSS | `liger_kernel.transformers.apply_liger_kernel_to_gpt_oss` | RoPE, RMSNorm, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
318
323
|
| InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
319
324
|
| HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
320
325
|
| HunyuanV1 MoE | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
@@ -444,3 +449,5 @@ url={https://openreview.net/forum?id=36SjAIT42G}
|
|
|
444
449
|
↑ Back to Top ↑
|
|
445
450
|
</a>
|
|
446
451
|
</p>
|
|
452
|
+
|
|
453
|
+
|
|
@@ -31,8 +31,8 @@
|
|
|
31
31
|
</a>
|
|
32
32
|
</td>
|
|
33
33
|
<td style="padding: 10px;">
|
|
34
|
-
<a href="https://discord.gg/
|
|
35
|
-
<img src="https://dcbadge.limes.pink/api/server/
|
|
34
|
+
<a href="https://discord.gg/X4MaxPgA">
|
|
35
|
+
<img src="https://dcbadge.limes.pink/api/server/https://discord.gg/X4MaxPgA?style=flat" alt="Join Our Discord">
|
|
36
36
|
</a>
|
|
37
37
|
</td>
|
|
38
38
|
</tr>
|
|
@@ -47,6 +47,7 @@
|
|
|
47
47
|
<details>
|
|
48
48
|
<summary>Latest News 🔥</summary>
|
|
49
49
|
|
|
50
|
+
- [2025/12/19] We announced a liger kernel discord channel at https://discord.gg/X4MaxPgA; We will be hosting Liger Kernel x Triton China Meetup in mid of January 2026
|
|
50
51
|
- [2025/03/06] We release a joint blog post on TorchTune × Liger - [Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel](https://pytorch.org/blog/peak-performance-minimized-memory/)
|
|
51
52
|
- [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
|
|
52
53
|
- [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
|
|
@@ -65,6 +66,8 @@ We've also added optimized Post-Training kernels that deliver **up to 80% memory
|
|
|
65
66
|
|
|
66
67
|
You can view the documentation site for additional installation, usage examples, and API references:https://linkedin.github.io/Liger-Kernel/
|
|
67
68
|
|
|
69
|
+
You can view the Liger Kernel Technical Report: https://openreview.net/forum?id=36SjAIT42G
|
|
70
|
+
|
|
68
71
|
## Supercharge Your Model with Liger Kernel
|
|
69
72
|
|
|
70
73
|

|
|
@@ -264,6 +267,7 @@ loss.backward()
|
|
|
264
267
|
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
265
268
|
| Olmo3 | `liger_kernel.transformers.apply_liger_kernel_to_olmo3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
266
269
|
| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
270
|
+
| GPT-OSS | `liger_kernel.transformers.apply_liger_kernel_to_gpt_oss` | RoPE, RMSNorm, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
267
271
|
| InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
268
272
|
| HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
269
273
|
| HunyuanV1 MoE | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
@@ -393,3 +397,5 @@ url={https://openreview.net/forum?id=36SjAIT42G}
|
|
|
393
397
|
↑ Back to Top ↑
|
|
394
398
|
</a>
|
|
395
399
|
</p>
|
|
400
|
+
|
|
401
|
+
|
|
@@ -12,6 +12,7 @@ from utils import parse_benchmark_script_args
|
|
|
12
12
|
from utils import run_benchmarks
|
|
13
13
|
|
|
14
14
|
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
|
|
15
|
+
from liger_kernel.utils import get_total_gpu_memory
|
|
15
16
|
from liger_kernel.utils import infer_device
|
|
16
17
|
|
|
17
18
|
device = infer_device()
|
|
@@ -224,12 +225,20 @@ def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
|
|
|
224
225
|
|
|
225
226
|
if __name__ == "__main__":
|
|
226
227
|
args = parse_benchmark_script_args()
|
|
228
|
+
gpu_memory_gbs = get_total_gpu_memory()
|
|
229
|
+
# We know that the full test will require 69GBs for vocab size 2^13 and 39GBs for vocab size 2^12 on torch
|
|
230
|
+
if gpu_memory_gbs >= 69:
|
|
231
|
+
x_max = 13
|
|
232
|
+
elif gpu_memory_gbs >= 39:
|
|
233
|
+
x_max = 12
|
|
234
|
+
else:
|
|
235
|
+
x_max = 11
|
|
227
236
|
|
|
228
237
|
common_configs = {
|
|
229
238
|
"kernel_name": "distill_jsd_loss",
|
|
230
239
|
"x_name": "BT",
|
|
231
240
|
"x_label": "B x T",
|
|
232
|
-
"x_values": [2**i for i in range(10,
|
|
241
|
+
"x_values": [2**i for i in range(10, x_max + 1)],
|
|
233
242
|
"kernel_providers": ["liger", "torch"],
|
|
234
243
|
"extra_benchmark_configs": [
|
|
235
244
|
{
|
|
@@ -9,6 +9,7 @@ from utils import parse_benchmark_script_args
|
|
|
9
9
|
from utils import run_benchmarks
|
|
10
10
|
|
|
11
11
|
from liger_kernel.transformers.jsd import LigerJSD
|
|
12
|
+
from liger_kernel.utils import get_total_gpu_memory
|
|
12
13
|
from liger_kernel.utils import infer_device
|
|
13
14
|
|
|
14
15
|
device = infer_device()
|
|
@@ -123,11 +124,17 @@ def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
|
|
|
123
124
|
|
|
124
125
|
if __name__ == "__main__":
|
|
125
126
|
args = parse_benchmark_script_args()
|
|
127
|
+
gpu_memory_gbs = get_total_gpu_memory()
|
|
128
|
+
# We know that the full test will require 54GBs for vocab size 2^17 on torch
|
|
129
|
+
if gpu_memory_gbs >= 54:
|
|
130
|
+
x_max = 17
|
|
131
|
+
else:
|
|
132
|
+
x_max = 16
|
|
126
133
|
common_args = {
|
|
127
134
|
"kernel_name": "jsd",
|
|
128
135
|
"x_name": "V",
|
|
129
136
|
"x_label": "vocab size",
|
|
130
|
-
"x_values": [2**i for i in range(12,
|
|
137
|
+
"x_values": [2**i for i in range(12, x_max + 1)],
|
|
131
138
|
"kernel_providers": ["liger", "torch"],
|
|
132
139
|
"extra_benchmark_configs": [{"B": 4, "T": 2048}],
|
|
133
140
|
"overwrite": args.overwrite,
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
|
|
4
|
+
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLTextConfig
|
|
4
5
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding
|
|
5
6
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
|
|
6
7
|
from utils import QUANTILES
|
|
@@ -32,7 +33,20 @@ def bench_speed_qwen2vl_mrope(
|
|
|
32
33
|
seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
|
|
33
34
|
|
|
34
35
|
head_dim = hidden_size // num_q_heads
|
|
35
|
-
|
|
36
|
+
mrope_section_hw = head_dim * 3 // 16
|
|
37
|
+
mrope_section = [
|
|
38
|
+
head_dim // 2 - 2 * mrope_section_hw,
|
|
39
|
+
mrope_section_hw,
|
|
40
|
+
mrope_section_hw,
|
|
41
|
+
]
|
|
42
|
+
config = Qwen2VLTextConfig(
|
|
43
|
+
hidden_size=hidden_size,
|
|
44
|
+
num_attention_heads=num_q_heads,
|
|
45
|
+
num_key_value_heads=num_kv_heads,
|
|
46
|
+
rope_theta=1000000.0,
|
|
47
|
+
mrope_section=mrope_section,
|
|
48
|
+
)
|
|
49
|
+
rotary_emb = Qwen2VLRotaryEmbedding(config, device=device)
|
|
36
50
|
q = torch.randn(
|
|
37
51
|
(1, seq_len, num_q_heads, head_dim),
|
|
38
52
|
device=device,
|
|
@@ -47,18 +61,11 @@ def bench_speed_qwen2vl_mrope(
|
|
|
47
61
|
).transpose(1, 2)
|
|
48
62
|
dq, dk = (
|
|
49
63
|
torch.randn_like(q, device=device, dtype=dtype),
|
|
50
|
-
torch.randn_like(k, device=device),
|
|
64
|
+
torch.randn_like(k, device=device, dtype=dtype),
|
|
51
65
|
)
|
|
52
66
|
pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
|
|
53
67
|
cos, sin = rotary_emb(k, pos_ids)
|
|
54
68
|
|
|
55
|
-
mrope_section_hw = head_dim * 3 // 16
|
|
56
|
-
mrope_section = [
|
|
57
|
-
head_dim // 2 - 2 * mrope_section_hw,
|
|
58
|
-
mrope_section_hw,
|
|
59
|
-
mrope_section_hw,
|
|
60
|
-
]
|
|
61
|
-
|
|
62
69
|
def fwd():
|
|
63
70
|
if provider == "liger":
|
|
64
71
|
return liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
|
|
@@ -116,7 +123,21 @@ def bench_memory_qwen2vl_mrope(
|
|
|
116
123
|
seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
|
|
117
124
|
|
|
118
125
|
head_dim = hidden_size // num_q_heads
|
|
119
|
-
|
|
126
|
+
|
|
127
|
+
mrope_section_hw = head_dim * 3 // 16
|
|
128
|
+
mrope_section = [
|
|
129
|
+
head_dim // 2 - 2 * mrope_section_hw,
|
|
130
|
+
mrope_section_hw,
|
|
131
|
+
mrope_section_hw,
|
|
132
|
+
]
|
|
133
|
+
config = Qwen2VLTextConfig(
|
|
134
|
+
hidden_size=hidden_size,
|
|
135
|
+
num_attention_heads=num_q_heads,
|
|
136
|
+
num_key_value_heads=num_kv_heads,
|
|
137
|
+
rope_theta=1000000.0,
|
|
138
|
+
mrope_section=mrope_section,
|
|
139
|
+
)
|
|
140
|
+
rotary_emb = Qwen2VLRotaryEmbedding(config, device=device)
|
|
120
141
|
q = torch.randn(
|
|
121
142
|
(1, seq_len, num_q_heads, head_dim),
|
|
122
143
|
device=device,
|
|
@@ -131,18 +152,11 @@ def bench_memory_qwen2vl_mrope(
|
|
|
131
152
|
).transpose(1, 2)
|
|
132
153
|
dq, dk = (
|
|
133
154
|
torch.randn_like(q, device=device, dtype=dtype),
|
|
134
|
-
torch.randn_like(k, device=device),
|
|
155
|
+
torch.randn_like(k, device=device, dtype=dtype),
|
|
135
156
|
)
|
|
136
157
|
pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
|
|
137
158
|
cos, sin = rotary_emb(k, pos_ids)
|
|
138
159
|
|
|
139
|
-
mrope_section_hw = head_dim * 3 // 16
|
|
140
|
-
mrope_section = [
|
|
141
|
-
head_dim // 2 - 2 * mrope_section_hw,
|
|
142
|
-
mrope_section_hw,
|
|
143
|
-
mrope_section_hw,
|
|
144
|
-
]
|
|
145
|
-
|
|
146
160
|
def full():
|
|
147
161
|
if provider == "liger":
|
|
148
162
|
q_out, k_out = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
|
|
@@ -9,6 +9,10 @@ from utils import parse_benchmark_script_args
|
|
|
9
9
|
from utils import run_benchmarks
|
|
10
10
|
|
|
11
11
|
from liger_kernel.transformers.tvd import LigerTVDLoss
|
|
12
|
+
from liger_kernel.utils import get_total_gpu_memory
|
|
13
|
+
from liger_kernel.utils import infer_device
|
|
14
|
+
|
|
15
|
+
device = infer_device()
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
class TorchTVDLoss(torch.nn.Module):
|
|
@@ -40,8 +44,8 @@ def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
|
|
|
40
44
|
torch_tvd = TorchTVDLoss(reduction=reduction)
|
|
41
45
|
liger_tvd = LigerTVDLoss(reduction=reduction)
|
|
42
46
|
|
|
43
|
-
_input = torch.randn(B * T, V, requires_grad=True, device=
|
|
44
|
-
target = torch.randn(B * T, V, device=
|
|
47
|
+
_input = torch.randn(B * T, V, requires_grad=True, device=device).softmax(dim=-1)
|
|
48
|
+
target = torch.randn(B * T, V, device=device).softmax(dim=-1)
|
|
45
49
|
|
|
46
50
|
def fwd():
|
|
47
51
|
if input.kernel_provider == "liger":
|
|
@@ -82,8 +86,8 @@ def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
|
|
|
82
86
|
V = input.x
|
|
83
87
|
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
|
|
84
88
|
|
|
85
|
-
_input = torch.randn(B * T, V, requires_grad=True, device=
|
|
86
|
-
target = torch.randn(B * T, V, device=
|
|
89
|
+
_input = torch.randn(B * T, V, requires_grad=True, device=device).softmax(dim=-1)
|
|
90
|
+
target = torch.randn(B * T, V, device=device).softmax(dim=-1)
|
|
87
91
|
|
|
88
92
|
def fwd():
|
|
89
93
|
if input.kernel_provider == "liger":
|
|
@@ -106,11 +110,17 @@ def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
|
|
|
106
110
|
|
|
107
111
|
if __name__ == "__main__":
|
|
108
112
|
args = parse_benchmark_script_args()
|
|
113
|
+
gpu_memory_gbs = get_total_gpu_memory()
|
|
114
|
+
# We know that the full test will require 66GBs for vocab size 2^17
|
|
115
|
+
if gpu_memory_gbs >= 66:
|
|
116
|
+
x_max = 17
|
|
117
|
+
else:
|
|
118
|
+
x_max = 16
|
|
109
119
|
common_args = {
|
|
110
120
|
"kernel_name": "tvd",
|
|
111
121
|
"x_name": "V",
|
|
112
122
|
"x_label": "vocab size",
|
|
113
|
-
"x_values": [2**i for i in range(12,
|
|
123
|
+
"x_values": [2**i for i in range(12, x_max + 1)],
|
|
114
124
|
"kernel_providers": ["liger", "torch"],
|
|
115
125
|
"extra_benchmark_configs": [{"B": 8, "T": 2048}],
|
|
116
126
|
"overwrite": args.overwrite,
|