liger-kernel 0.5.3__tar.gz → 0.5.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel-0.5.5/.github/workflows/amd-ci.yml +74 -0
- liger_kernel-0.5.3/.github/workflows/amd-ci.yml → liger_kernel-0.5.5/.github/workflows/intel-ci.yml +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/Makefile +11 -5
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/PKG-INFO +19 -4
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/README.md +18 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/data/all_benchmark_data.csv +66 -30
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_distill_jsd_loss.py +2 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_kto_loss.py +10 -10
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_rope.py +1 -1
- liger_kernel-0.5.5/benchmark/scripts/benchmark_tvd.py +133 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/dev/modal/tests.py +1 -1
- liger_kernel-0.5.5/docs/images/post-training.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/lightning/training.py +1 -1
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/callback.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/pyproject.toml +1 -1
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/setup.py +13 -4
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/__init__.py +1 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/cpo_loss.py +51 -11
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/dpo_loss.py +30 -4
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
- liger_kernel-0.5.5/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +240 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
- liger_kernel-0.5.5/src/liger_kernel/chunked_loss/grpo_loss.py +194 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/jsd_loss.py +31 -6
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/kto_loss.py +53 -15
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/orpo_loss.py +37 -5
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/simpo_loss.py +47 -11
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/cross_entropy.py +7 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/fused_linear_cross_entropy.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/fused_linear_jsd.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/jsd.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/layer_norm.py +20 -7
- liger_kernel-0.5.5/src/liger_kernel/ops/tvd.py +207 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/utils.py +1 -2
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/__init__.py +4 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/cross_entropy.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/functional.py +17 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/group_norm.py +6 -6
- liger_kernel-0.5.5/src/liger_kernel/transformers/model/olmo2.py +124 -0
- liger_kernel-0.5.5/src/liger_kernel/transformers/model/qwen2_5_vl.py +205 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/monkey_patch.py +239 -27
- liger_kernel-0.5.5/src/liger_kernel/transformers/tvd.py +13 -0
- liger_kernel-0.5.5/src/liger_kernel/utils.py +60 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/PKG-INFO +19 -4
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/SOURCES.txt +20 -3
- liger_kernel-0.5.5/test/chunked_loss/test_grpo_loss.py +275 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_jsd_loss.py +49 -10
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_kto_loss.py +85 -8
- liger_kernel-0.5.5/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.5/test/convergence/bf16}/test_mini_models.py +206 -36
- {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.5/test/convergence/bf16}/test_mini_models_multimodal.py +92 -28
- {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.5/test/convergence/bf16}/test_mini_models_with_logits.py +205 -35
- liger_kernel-0.5.5/test/convergence/fp32/__init__.py +0 -0
- liger_kernel-0.5.5/test/convergence/fp32/test_mini_models.py +747 -0
- liger_kernel-0.5.5/test/convergence/fp32/test_mini_models_multimodal.py +513 -0
- liger_kernel-0.5.5/test/convergence/fp32/test_mini_models_with_logits.py +746 -0
- liger_kernel-0.5.5/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +63 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_cross_entropy.py +39 -0
- liger_kernel-0.5.5/test/transformers/test_flex_attention.py +291 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_layer_norm.py +20 -5
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_monkey_patch.py +122 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_qwen2vl_mrope.py +3 -2
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_rope.py +1 -1
- liger_kernel-0.5.5/test/transformers/test_tvd.py +188 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/triton/test_triton_monkey_patch.py +3 -3
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/utils.py +36 -42
- liger_kernel-0.5.3/docs/images/post-training.png +0 -0
- liger_kernel-0.5.3/src/liger_kernel/utils.py +0 -13
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/pull_request_template.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/workflows/docs.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.gitignore +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/LICENSE +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/NOTICE +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/README.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/utils.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/dev/fmt-requirements.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/Examples.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/Getting-Started.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/High-Level-APIs.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/acknowledgement.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/contributing.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/banner.GIF +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/compose.gif +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/e2e-memory.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/e2e-tps.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/logo-banner.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/patch.gif +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/index.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/license.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/README.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/callback.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/training.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/lightning/README.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/lightning/requirements.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/README.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/requirements.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/train.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/mkdocs.yml +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/setup.cfg +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/requires.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/top_level.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/conftest.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/convergence/__init__.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_embedding.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_geglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_jsd.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_transformers.py +0 -0
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
name: AMD GPU
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches:
|
|
6
|
+
- main
|
|
7
|
+
paths:
|
|
8
|
+
- "src/**"
|
|
9
|
+
- "test/**"
|
|
10
|
+
pull_request:
|
|
11
|
+
branches:
|
|
12
|
+
- main
|
|
13
|
+
paths:
|
|
14
|
+
- "src/**"
|
|
15
|
+
- "test/**"
|
|
16
|
+
schedule:
|
|
17
|
+
# Runs at 00:00 UTC daily
|
|
18
|
+
- cron: '0 0 * * *'
|
|
19
|
+
workflow_dispatch: # Enables manual trigger
|
|
20
|
+
|
|
21
|
+
concurrency:
|
|
22
|
+
# This causes it to cancel previous in-progress actions on the same PR / branch,
|
|
23
|
+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
|
24
|
+
cancel-in-progress: true
|
|
25
|
+
|
|
26
|
+
jobs:
|
|
27
|
+
checkstyle:
|
|
28
|
+
runs-on: ubuntu-latest
|
|
29
|
+
|
|
30
|
+
steps:
|
|
31
|
+
- name: Checkout code
|
|
32
|
+
uses: actions/checkout@v3
|
|
33
|
+
|
|
34
|
+
- name: Set up Python
|
|
35
|
+
uses: actions/setup-python@v3
|
|
36
|
+
with:
|
|
37
|
+
python-version: '3.10'
|
|
38
|
+
|
|
39
|
+
- name: Install dependencies
|
|
40
|
+
run: |
|
|
41
|
+
python -m pip install --upgrade pip
|
|
42
|
+
pip install -r dev/fmt-requirements.txt
|
|
43
|
+
|
|
44
|
+
- name: Run checkstyle
|
|
45
|
+
run: make checkstyle
|
|
46
|
+
|
|
47
|
+
tests:
|
|
48
|
+
runs-on: linux-mi300-gpu-1
|
|
49
|
+
needs: [checkstyle]
|
|
50
|
+
strategy:
|
|
51
|
+
matrix:
|
|
52
|
+
rocm_version: ['6.2', '6.3']
|
|
53
|
+
|
|
54
|
+
steps:
|
|
55
|
+
- name: Checkout code
|
|
56
|
+
uses: actions/checkout@v3
|
|
57
|
+
|
|
58
|
+
- name: Set up Python
|
|
59
|
+
uses: actions/setup-python@v3
|
|
60
|
+
with:
|
|
61
|
+
python-version: '3.10'
|
|
62
|
+
|
|
63
|
+
- name: Setup Dependencies
|
|
64
|
+
run: |
|
|
65
|
+
python -m pip install --upgrade pip
|
|
66
|
+
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm${{ matrix.rocm_version }}
|
|
67
|
+
|
|
68
|
+
- name: List Python Environments
|
|
69
|
+
run: python -m pip list
|
|
70
|
+
|
|
71
|
+
- name: Run Unit Tests
|
|
72
|
+
run: |
|
|
73
|
+
make test
|
|
74
|
+
make test-convergence
|
liger_kernel-0.5.3/.github/workflows/amd-ci.yml → liger_kernel-0.5.5/.github/workflows/intel-ci.yml
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
name:
|
|
1
|
+
name: Intel GPU
|
|
2
2
|
|
|
3
3
|
on:
|
|
4
4
|
push:
|
|
@@ -45,7 +45,7 @@ jobs:
|
|
|
45
45
|
run: make checkstyle
|
|
46
46
|
|
|
47
47
|
tests:
|
|
48
|
-
runs-on: linux-
|
|
48
|
+
runs-on: linux-max1550-gpu-8
|
|
49
49
|
needs: [checkstyle]
|
|
50
50
|
|
|
51
51
|
steps:
|
|
@@ -60,7 +60,7 @@ jobs:
|
|
|
60
60
|
- name: Setup Dependencies
|
|
61
61
|
run: |
|
|
62
62
|
python -m pip install --upgrade pip
|
|
63
|
-
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/
|
|
63
|
+
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/test/xpu
|
|
64
64
|
|
|
65
65
|
- name: List Python Environments
|
|
66
66
|
run: python -m pip list
|
|
@@ -9,8 +9,10 @@ test:
|
|
|
9
9
|
|
|
10
10
|
# Command to run ruff for linting and formatting code
|
|
11
11
|
checkstyle:
|
|
12
|
-
ruff check
|
|
13
|
-
ruff format .; ruff_format_status=$$?; \
|
|
12
|
+
ruff check .; ruff_check_status=$$?; \
|
|
13
|
+
ruff format --check .; ruff_format_status=$$?; \
|
|
14
|
+
ruff check . --fix; \
|
|
15
|
+
ruff format .; \
|
|
14
16
|
if [ $$ruff_check_status -ne 0 ] || [ $$ruff_format_status -ne 0 ]; then \
|
|
15
17
|
exit 1; \
|
|
16
18
|
fi
|
|
@@ -18,9 +20,13 @@ checkstyle:
|
|
|
18
20
|
# Command to run pytest for convergence tests
|
|
19
21
|
# We have to explicitly set HF_DATASETS_OFFLINE=1, or dataset will silently try to send metrics and timeout (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286
|
|
20
22
|
test-convergence:
|
|
21
|
-
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
|
|
22
|
-
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_multimodal.py
|
|
23
|
-
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_with_logits.py
|
|
23
|
+
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
|
|
24
|
+
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_multimodal.py
|
|
25
|
+
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_with_logits.py
|
|
26
|
+
|
|
27
|
+
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models.py
|
|
28
|
+
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
|
|
29
|
+
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py
|
|
24
30
|
|
|
25
31
|
# Command to run all benchmark scripts and update benchmarking data file
|
|
26
32
|
# By default this doesn't overwrite existing data for the same benchmark experiment
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: liger_kernel
|
|
3
|
-
Version: 0.5.3
|
|
3
|
+
Version: 0.5.5
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -97,6 +97,11 @@ Dynamic: requires-dist
|
|
|
97
97
|
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
98
98
|
</a>
|
|
99
99
|
</div>
|
|
100
|
+
<div style="display: block;">
|
|
101
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
|
|
102
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
103
|
+
</a>
|
|
104
|
+
</div>
|
|
100
105
|
</td>
|
|
101
106
|
</tr>
|
|
102
107
|
</table>
|
|
@@ -123,7 +128,7 @@ Dynamic: requires-dist
|
|
|
123
128
|
|
|
124
129
|
**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
|
|
125
130
|
|
|
126
|
-
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
|
131
|
+
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
|
127
132
|
|
|
128
133
|
## Supercharge Your Model with Liger Kernel
|
|
129
134
|
|
|
@@ -149,7 +154,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
|
|
|
149
154
|
We provide optimized post-training kernels like DPO, ORPO, SimPO, and more, which can reduce memory usage by up to 80%. You can easily use them as Python modules.
|
|
150
155
|
|
|
151
156
|
```python
|
|
152
|
-
from liger_kernel.chunked_loss import
|
|
157
|
+
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
|
|
153
158
|
orpo_loss = LigerFusedLinearORPOLoss()
|
|
154
159
|
y = orpo_loss(lm_head.weight, x, target)
|
|
155
160
|
```
|
|
@@ -188,6 +193,11 @@ y = orpo_loss(lm_head.weight, x, target)
|
|
|
188
193
|
- `torch >= 2.5.0` Install according to the instructions on the official PyTorch website.
|
|
189
194
|
- `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
|
|
190
195
|
|
|
196
|
+
```bash
|
|
197
|
+
# Need to pass the url when installing
|
|
198
|
+
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
|
|
199
|
+
```
|
|
200
|
+
|
|
191
201
|
### Optional Dependencies
|
|
192
202
|
|
|
193
203
|
- `transformers >= 4.x`: Required if you plan to use the transformers model patching APIs. The specific model you are working with will dictate the minimum version of transformers.
|
|
@@ -304,7 +314,10 @@ loss.backward()
|
|
|
304
314
|
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
305
315
|
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
306
316
|
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
317
|
+
| Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
307
318
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
319
|
+
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
|
|
320
|
+
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
308
321
|
|
|
309
322
|
|
|
310
323
|
## Low-level APIs
|
|
@@ -333,6 +346,7 @@ loss.backward()
|
|
|
333
346
|
| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
|
|
334
347
|
| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
|
|
335
348
|
| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
|
|
349
|
+
| Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
|
|
336
350
|
|
|
337
351
|
### Distillation Kernels
|
|
338
352
|
|
|
@@ -341,6 +355,7 @@ loss.backward()
|
|
|
341
355
|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
|
|
342
356
|
| JSD | `liger_kernel.transformers.LigerJSD` |
|
|
343
357
|
| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
|
|
358
|
+
| TVD | `liger_kernel.transformers.LigerTVDLoss` |
|
|
344
359
|
|
|
345
360
|
### Experimental Kernels
|
|
346
361
|
|
|
@@ -372,7 +387,7 @@ loss.backward()
|
|
|
372
387
|
|
|
373
388
|
- For issues, create a GitHub ticket in this repository
|
|
374
389
|
- For open discussion, join [our discord channel](https://discord.gg/gpumode)
|
|
375
|
-
- For formal collaboration, send an email to
|
|
390
|
+
- For formal collaboration, send an email to yannchen@linkedin.com
|
|
376
391
|
|
|
377
392
|
## Cite this work
|
|
378
393
|
|
|
@@ -47,6 +47,11 @@
|
|
|
47
47
|
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
48
48
|
</a>
|
|
49
49
|
</div>
|
|
50
|
+
<div style="display: block;">
|
|
51
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
|
|
52
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
53
|
+
</a>
|
|
54
|
+
</div>
|
|
50
55
|
</td>
|
|
51
56
|
</tr>
|
|
52
57
|
</table>
|
|
@@ -73,7 +78,7 @@
|
|
|
73
78
|
|
|
74
79
|
**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
|
|
75
80
|
|
|
76
|
-
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
|
81
|
+
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
|
77
82
|
|
|
78
83
|
## Supercharge Your Model with Liger Kernel
|
|
79
84
|
|
|
@@ -99,7 +104,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
|
|
|
99
104
|
We provide optimized post-training kernels like DPO, ORPO, SimPO, and more, which can reduce memory usage by up to 80%. You can easily use them as Python modules.
|
|
100
105
|
|
|
101
106
|
```python
|
|
102
|
-
from liger_kernel.chunked_loss import
|
|
107
|
+
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
|
|
103
108
|
orpo_loss = LigerFusedLinearORPOLoss()
|
|
104
109
|
y = orpo_loss(lm_head.weight, x, target)
|
|
105
110
|
```
|
|
@@ -138,6 +143,11 @@ y = orpo_loss(lm_head.weight, x, target)
|
|
|
138
143
|
- `torch >= 2.5.0` Install according to the instructions on the official PyTorch webpage.
|
|
139
144
|
- `triton >= 3.0.0` Install from PyPI. (e.g. `pip install triton==3.0.0`)
|
|
140
145
|
|
|
146
|
+
```bash
|
|
147
|
+
# The extra index URL must be passed when installing
|
|
148
|
+
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
|
|
149
|
+
```
|
|
150
|
+
|
|
141
151
|
### Optional Dependencies
|
|
142
152
|
|
|
143
153
|
- `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working with will dictate the minimum version of transformers.
|
|
@@ -254,7 +264,10 @@ loss.backward()
|
|
|
254
264
|
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
255
265
|
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
256
266
|
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
267
|
+
| Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
257
268
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
269
|
+
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
|
|
270
|
+
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
258
271
|
|
|
259
272
|
|
|
260
273
|
## Low-level APIs
|
|
@@ -283,6 +296,7 @@ loss.backward()
|
|
|
283
296
|
| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
|
|
284
297
|
| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
|
|
285
298
|
| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
|
|
299
|
+
| Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
|
|
286
300
|
|
|
287
301
|
### Distillation Kernels
|
|
288
302
|
|
|
@@ -291,6 +305,7 @@ loss.backward()
|
|
|
291
305
|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
|
|
292
306
|
| JSD | `liger_kernel.transformers.LigerJSD` |
|
|
293
307
|
| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
|
|
308
|
+
| TVD | `liger_kernel.transformers.LigerTVDLoss` |
|
|
294
309
|
|
|
295
310
|
### Experimental Kernels
|
|
296
311
|
|
|
@@ -322,7 +337,7 @@ loss.backward()
|
|
|
322
337
|
|
|
323
338
|
- For issues, create a GitHub issue in this repository
|
|
324
339
|
- For open discussion, join [our Discord channel](https://discord.gg/gpumode)
|
|
325
|
-
- For formal collaboration, send an email to
|
|
340
|
+
- For formal collaboration, send an email to yannchen@linkedin.com
|
|
326
341
|
|
|
327
342
|
## Cite this work
|
|
328
343
|
|
|
@@ -505,6 +505,42 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859
|
|
|
505
505
|
fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
|
|
506
506
|
fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
|
|
507
507
|
fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
|
|
508
|
+
tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
509
|
+
tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
510
|
+
tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
511
|
+
tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
512
|
+
tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
513
|
+
tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
514
|
+
tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
515
|
+
tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
516
|
+
tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
517
|
+
tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
518
|
+
tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
519
|
+
tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
|
|
520
|
+
tvd,liger,forward,speed,ms,V,vocab size,4096,0.47814399003982544,0.4774720072746277,0.4790079891681671,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
|
521
|
+
tvd,liger,forward,speed,ms,V,vocab size,8192,0.906495988368988,0.905951976776123,0.9073920249938965,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
|
522
|
+
tvd,liger,forward,speed,ms,V,vocab size,16384,1.8787360191345215,1.8778239488601685,1.8797119855880737,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
|
523
|
+
tvd,liger,forward,speed,ms,V,vocab size,32768,3.5788800716400146,3.5772159099578857,3.58076810836792,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
|
524
|
+
tvd,liger,forward,speed,ms,V,vocab size,65536,7.008831977844238,7.007718086242676,7.010636806488037,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
|
525
|
+
tvd,liger,forward,speed,ms,V,vocab size,131072,13.88646411895752,13.88128662109375,13.890560150146484,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
|
|
526
|
+
tvd,torch,forward,speed,ms,V,vocab size,4096,1.308608055114746,1.306502342224121,1.3104127645492554,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
|
527
|
+
tvd,torch,forward,speed,ms,V,vocab size,8192,2.4735519886016846,2.472287893295288,2.4749441146850586,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
|
528
|
+
tvd,torch,forward,speed,ms,V,vocab size,16384,4.828320026397705,4.826848030090332,4.830643177032471,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
|
529
|
+
tvd,torch,forward,speed,ms,V,vocab size,32768,9.5206880569458,9.517024040222168,9.525145530700684,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
|
530
|
+
tvd,torch,forward,speed,ms,V,vocab size,65536,19.01535987854004,19.011123657226562,19.01806640625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
|
531
|
+
tvd,torch,forward,speed,ms,V,vocab size,131072,38.022865295410156,38.01945877075195,38.02627182006836,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
|
|
532
|
+
tvd,liger,full,speed,ms,V,vocab size,4096,2.626512050628662,2.621260643005371,2.646751880645752,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
|
533
|
+
tvd,liger,full,speed,ms,V,vocab size,8192,4.661711692810059,4.657618999481201,4.662930965423584,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
|
534
|
+
tvd,liger,full,speed,ms,V,vocab size,16384,9.088272094726562,9.080741882324219,9.092268943786621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
|
535
|
+
tvd,liger,full,speed,ms,V,vocab size,32768,18.116064071655273,18.112728118896484,18.118234634399414,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
|
536
|
+
tvd,liger,full,speed,ms,V,vocab size,65536,35.85124969482422,35.849971771240234,35.85252380371094,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
|
537
|
+
tvd,liger,full,speed,ms,V,vocab size,131072,71.1648941040039,71.1648941040039,71.1648941040039,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
|
|
538
|
+
tvd,torch,full,speed,ms,V,vocab size,4096,4.361599922180176,4.360159873962402,4.3639678955078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
|
539
|
+
tvd,torch,full,speed,ms,V,vocab size,8192,8.11302375793457,8.11075210571289,8.114463806152344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
|
540
|
+
tvd,torch,full,speed,ms,V,vocab size,16384,15.841055870056152,15.837087631225586,15.841856002807617,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
|
541
|
+
tvd,torch,full,speed,ms,V,vocab size,32768,31.71219253540039,31.706951141357422,31.715898513793945,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
|
542
|
+
tvd,torch,full,speed,ms,V,vocab size,65536,63.17919921875,63.17919921875,63.17919921875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
|
543
|
+
tvd,torch,full,speed,ms,V,vocab size,131072,126.0436782836914,126.0436782836914,126.0436782836914,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
|
|
508
544
|
group_norm,liger,forward,speed,ms,C,num_channels,32,0.03481600061058998,0.03379200026392937,0.03993599861860275,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
|
|
509
545
|
group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05119999870657921,0.05222399905323982,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
|
|
510
546
|
group_norm,liger,forward,speed,ms,C,num_channels,128,0.08499199897050858,0.08396799862384796,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
|
|
@@ -715,36 +751,6 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314
|
|
|
715
751
|
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
|
|
716
752
|
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
|
|
717
753
|
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
|
|
718
|
-
kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,7.841599941253662,7.801983833312988,7.849664211273193,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
|
|
719
|
-
kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,15.568096160888672,15.555737495422363,16.054176330566406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
|
|
720
|
-
kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,31.145376205444336,30.750951766967773,31.5398006439209,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
|
|
721
|
-
kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,61.49708938598633,61.49708938598633,61.49708938598633,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
|
|
722
|
-
kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,122.01449584960938,122.01449584960938,122.01449584960938,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
|
|
723
|
-
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,7.892335891723633,7.8687615394592285,8.03729248046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
|
|
724
|
-
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,14.16302490234375,13.813311576843262,15.860223770141602,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
|
|
725
|
-
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,25.56470489501953,25.564167022705078,25.641658782958984,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
|
|
726
|
-
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,53.0928955078125,53.0928955078125,53.0928955078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
|
|
727
|
-
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,108.76080322265625,108.76080322265625,108.76080322265625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
|
|
728
|
-
kto_loss,liger,full,speed,ms,B,Batch Size (B),2,8.662687301635742,8.488287925720215,9.611334800720215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
|
|
729
|
-
kto_loss,liger,full,speed,ms,B,Batch Size (B),4,18.40096092224121,17.99224281311035,18.57883644104004,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
|
|
730
|
-
kto_loss,liger,full,speed,ms,B,Batch Size (B),8,32.09159851074219,31.708070755004883,32.475128173828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
|
|
731
|
-
kto_loss,liger,full,speed,ms,B,Batch Size (B),16,69.30239868164062,69.30239868164062,69.30239868164062,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
|
|
732
|
-
kto_loss,liger,full,speed,ms,B,Batch Size (B),32,124.2437744140625,124.2437744140625,124.2437744140625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
|
|
733
|
-
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,11.449472427368164,11.407564163208008,11.773555755615234,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
|
|
734
|
-
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,20.871471405029297,20.862951278686523,20.879276275634766,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
|
|
735
|
-
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,41.16409683227539,40.760780334472656,41.567413330078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
|
|
736
|
-
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,77.720703125,77.720703125,77.720703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
|
|
737
|
-
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,156.25794982910156,156.25794982910156,156.25794982910156,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
|
|
738
|
-
kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2027.48583984375,2027.48583984375,2027.48583984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
|
|
739
|
-
kto_loss,liger,full,memory,MB,B,Batch Size (B),4,2789.736328125,2789.736328125,2789.736328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
|
|
740
|
-
kto_loss,liger,full,memory,MB,B,Batch Size (B),8,2801.751953125,2801.751953125,2801.751953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
|
|
741
|
-
kto_loss,liger,full,memory,MB,B,Batch Size (B),16,2825.783203125,2825.783203125,2825.783203125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
|
|
742
|
-
kto_loss,liger,full,memory,MB,B,Batch Size (B),32,2873.845703125,2873.845703125,2873.845703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
|
|
743
|
-
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,3786.7373046875,3786.7373046875,3786.7373046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
|
|
744
|
-
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.25390625,5544.25390625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
|
|
745
|
-
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
|
|
746
|
-
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
|
|
747
|
-
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
|
|
748
754
|
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
|
|
749
755
|
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
|
|
750
756
|
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
|
|
@@ -769,3 +775,33 @@ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,
|
|
|
769
775
|
distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
|
|
770
776
|
distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
|
|
771
777
|
distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
|
|
778
|
+
kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,3.9951679706573486,3.991487979888916,4.002252578735352,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
|
|
779
|
+
kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,7.8037919998168945,7.788575649261475,7.808595180511475,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
|
|
780
|
+
kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,15.43172836303711,15.430015563964844,15.4335355758667,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
|
|
781
|
+
kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,30.66864013671875,30.66431999206543,30.670501708984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
|
|
782
|
+
kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,61.1163215637207,61.1163215637207,61.1163215637207,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
|
|
783
|
+
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,3.8766400814056396,3.8680384159088135,3.8897151947021484,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
|
|
784
|
+
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,7.213727951049805,7.206470489501953,7.229574680328369,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
|
|
785
|
+
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,13.828800201416016,13.810944557189941,13.834943771362305,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
|
|
786
|
+
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,27.0930233001709,27.08517074584961,27.09713363647461,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
|
|
787
|
+
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,54.13715362548828,54.13715362548828,54.13715362548828,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
|
|
788
|
+
kto_loss,liger,full,speed,ms,B,Batch Size (B),2,4.782928466796875,4.677459239959717,5.3430914878845215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
|
|
789
|
+
kto_loss,liger,full,speed,ms,B,Batch Size (B),4,8.517248153686523,8.481344223022461,8.561504364013672,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
|
|
790
|
+
kto_loss,liger,full,speed,ms,B,Batch Size (B),8,16.547504425048828,16.513471603393555,16.678144454956055,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
|
|
791
|
+
kto_loss,liger,full,speed,ms,B,Batch Size (B),16,31.891263961791992,31.819705963134766,32.274131774902344,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
|
|
792
|
+
kto_loss,liger,full,speed,ms,B,Batch Size (B),32,62.953758239746094,62.953758239746094,62.953758239746094,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
|
|
793
|
+
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,6.201632022857666,6.163315296173096,6.314668655395508,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
|
|
794
|
+
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,11.156224250793457,11.142304420471191,11.207296371459961,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
|
|
795
|
+
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,21.249855041503906,21.231891632080078,21.264543533325195,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
|
|
796
|
+
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,41.55686569213867,41.536956787109375,41.57677459716797,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
|
|
797
|
+
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,81.56924438476562,81.56924438476562,81.56924438476562,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
|
|
798
|
+
kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2585.73876953125,2585.73876953125,2585.73876953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
|
|
799
|
+
kto_loss,liger,full,memory,MB,B,Batch Size (B),4,3348.9892578125,3348.9892578125,3348.9892578125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
|
|
800
|
+
kto_loss,liger,full,memory,MB,B,Batch Size (B),8,3361.0048828125,3361.0048828125,3361.0048828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
|
|
801
|
+
kto_loss,liger,full,memory,MB,B,Batch Size (B),16,3385.0361328125,3385.0361328125,3385.0361328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
|
|
802
|
+
kto_loss,liger,full,memory,MB,B,Batch Size (B),32,3433.0986328125,3433.0986328125,3433.0986328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
|
|
803
|
+
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,4341.74951171875,4341.74951171875,4341.74951171875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
804
|
+
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.26513671875,6099.26513671875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
805
|
+
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
806
|
+
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
807
|
+
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
|
|
@@ -103,8 +103,8 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
|
|
|
103
103
|
H=H,
|
|
104
104
|
V=V,
|
|
105
105
|
dtype=dtype,
|
|
106
|
-
|
|
107
|
-
|
|
106
|
+
use_bias=bias,
|
|
107
|
+
use_ref_bias=bias,
|
|
108
108
|
ignore_index=ignore_index,
|
|
109
109
|
beta=beta,
|
|
110
110
|
).to(device)
|
|
@@ -113,8 +113,8 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
|
|
|
113
113
|
H=H,
|
|
114
114
|
V=V,
|
|
115
115
|
dtype=dtype,
|
|
116
|
-
|
|
117
|
-
|
|
116
|
+
use_bias=bias,
|
|
117
|
+
use_ref_bias=bias,
|
|
118
118
|
ignore_index=ignore_index,
|
|
119
119
|
beta=beta,
|
|
120
120
|
).to(device)
|
|
@@ -149,7 +149,7 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
|
|
|
149
149
|
y=target,
|
|
150
150
|
preference_labels=preference_labels,
|
|
151
151
|
kl=kl,
|
|
152
|
-
)
|
|
152
|
+
)[0]
|
|
153
153
|
elif provider == "huggingface":
|
|
154
154
|
return torch_kto_loss(
|
|
155
155
|
x=_input,
|
|
@@ -157,7 +157,7 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
|
|
|
157
157
|
y=target,
|
|
158
158
|
preference_labels=preference_labels,
|
|
159
159
|
kl=kl,
|
|
160
|
-
)
|
|
160
|
+
)[0]
|
|
161
161
|
|
|
162
162
|
def full():
|
|
163
163
|
y = fwd()
|
|
@@ -189,7 +189,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
|
|
|
189
189
|
dtype=dtype,
|
|
190
190
|
beta=beta,
|
|
191
191
|
ignore_index=ignore_index,
|
|
192
|
-
|
|
192
|
+
use_bias=bias,
|
|
193
193
|
).to(device)
|
|
194
194
|
liger_kto_loss = LigerLMHeadKTO(
|
|
195
195
|
H=H,
|
|
@@ -197,7 +197,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
|
|
|
197
197
|
dtype=dtype,
|
|
198
198
|
beta=beta,
|
|
199
199
|
ignore_index=ignore_index,
|
|
200
|
-
|
|
200
|
+
use_bias=bias,
|
|
201
201
|
).to(device)
|
|
202
202
|
|
|
203
203
|
# Input shape: [B, T, H]
|
|
@@ -230,7 +230,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
|
|
|
230
230
|
y=target,
|
|
231
231
|
preference_labels=preference_labels,
|
|
232
232
|
kl=kl,
|
|
233
|
-
)
|
|
233
|
+
)[0]
|
|
234
234
|
elif provider == "huggingface":
|
|
235
235
|
return torch_kto_loss(
|
|
236
236
|
x=_input,
|
|
@@ -238,7 +238,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
|
|
|
238
238
|
y=target,
|
|
239
239
|
preference_labels=preference_labels,
|
|
240
240
|
kl=kl,
|
|
241
|
-
)
|
|
241
|
+
)[0]
|
|
242
242
|
|
|
243
243
|
if mode == "forward":
|
|
244
244
|
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
|
|
4
|
-
from test.utils import transformers_version_dispatch
|
|
5
4
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
|
6
5
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
|
7
6
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
@@ -14,6 +13,7 @@ from utils import run_benchmarks
|
|
|
14
13
|
|
|
15
14
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
16
15
|
from liger_kernel.utils import infer_device
|
|
16
|
+
from liger_kernel.utils import transformers_version_dispatch
|
|
17
17
|
|
|
18
18
|
device = infer_device()
|
|
19
19
|
|