liger-kernel 0.5.8__tar.gz → 0.5.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/intel-ci.yml +24 -11
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/PKG-INFO +3 -1
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/README.md +2 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/mkdocs.yml +3 -3
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/pyproject.toml +1 -1
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/setup.py +0 -1
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/dpo_loss.py +8 -1
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/cross_entropy.py +4 -1
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/__init__.py +6 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/gemma.py +8 -4
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/gemma2.py +8 -4
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/gemma3.py +3 -1
- liger_kernel-0.5.9/src/liger_kernel/transformers/model/glm4.py +125 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/llama.py +8 -4
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/mistral.py +8 -4
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/mixtral.py +8 -4
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/mllama.py +8 -4
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/olmo2.py +8 -4
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/phi3.py +8 -4
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/qwen2.py +8 -4
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/qwen2_5_vl.py +3 -1
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/qwen2_vl.py +3 -1
- liger_kernel-0.5.9/src/liger_kernel/transformers/model/qwen3.py +118 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/monkey_patch.py +121 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel.egg-info/PKG-INFO +3 -1
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel.egg-info/SOURCES.txt +2 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/bf16/test_mini_models.py +119 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/bf16/test_mini_models_multimodal.py +3 -4
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/bf16/test_mini_models_with_logits.py +119 -1
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/fp32/test_mini_models.py +113 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/fp32/test_mini_models_multimodal.py +7 -6
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/fp32/test_mini_models_with_logits.py +113 -1
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_cross_entropy.py +108 -2
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_fused_linear_cross_entropy.py +5 -6
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_monkey_patch.py +100 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_rms_norm.py +1 -1
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/utils.py +24 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/pull_request_template.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/docs.yml +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.gitignore +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/LICENSE +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/Makefile +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/NOTICE +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/README.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/utils.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/dev/fmt-requirements.txt +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/dev/modal/tests.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/Examples.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/Getting-Started.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/High-Level-APIs.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/acknowledgement.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/contributing.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/banner.GIF +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/compose.gif +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/e2e-memory.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/e2e-tps.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/logo-banner.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/patch.gif +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/post-training.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/index.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/license.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/README.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/callback.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/training.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/lightning/README.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/lightning/requirements.txt +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/lightning/training.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/README.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/callback.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/requirements.txt +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/train.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/setup.cfg +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/dyt.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/dyt.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/gema3_rms.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/llava.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/paligemma.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/utils.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel.egg-info/requires.txt +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel.egg-info/top_level.txt +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/conftest.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_dyt.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_embedding.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_geglu.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_jsd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_rope.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_transformers.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_tvd.py +0 -0
- {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/triton/test_triton_monkey_patch.py +0 -0
|
@@ -45,27 +45,40 @@ jobs:
|
|
|
45
45
|
run: make checkstyle
|
|
46
46
|
|
|
47
47
|
tests:
|
|
48
|
-
runs-on: linux-max1550-
|
|
48
|
+
runs-on: linux-max1550-pvc-8
|
|
49
49
|
needs: [checkstyle]
|
|
50
|
-
|
|
50
|
+
if: success()
|
|
51
|
+
container:
|
|
52
|
+
image: intel/oneapi-basekit:2025.0.1-0-devel-ubuntu24.04
|
|
53
|
+
options: --privileged -v /dev/dri/by-path:/dev/dri/by-path --device=/dev/dri --ipc=host
|
|
51
54
|
steps:
|
|
55
|
+
- name: Set up python
|
|
56
|
+
shell: bash
|
|
57
|
+
run: |
|
|
58
|
+
apt-get update && \
|
|
59
|
+
apt-get install -y python3.12-venv python3-pip && \
|
|
60
|
+
ln -sf /usr/bin/python3 /usr/bin/python && \
|
|
61
|
+
apt-get clean && rm -rf /var/lib/apt/lists/*
|
|
62
|
+
|
|
52
63
|
- name: Checkout code
|
|
53
64
|
uses: actions/checkout@v3
|
|
54
|
-
|
|
55
|
-
- name: Set up Python
|
|
56
|
-
uses: actions/setup-python@v3
|
|
57
|
-
with:
|
|
58
|
-
python-version: '3.10'
|
|
59
|
-
|
|
65
|
+
|
|
60
66
|
- name: Setup Dependencies
|
|
67
|
+
shell: bash
|
|
61
68
|
run: |
|
|
62
|
-
python -m
|
|
69
|
+
python -m venv test-env
|
|
70
|
+
. test-env/bin/activate
|
|
63
71
|
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/test/xpu
|
|
64
|
-
|
|
72
|
+
|
|
65
73
|
- name: List Python Environments
|
|
66
|
-
|
|
74
|
+
shell: bash
|
|
75
|
+
run: |
|
|
76
|
+
. test-env/bin/activate
|
|
77
|
+
python -m pip list
|
|
67
78
|
|
|
68
79
|
- name: Run Unit Tests
|
|
80
|
+
shell: bash
|
|
69
81
|
run: |
|
|
82
|
+
. test-env/bin/activate
|
|
70
83
|
make test
|
|
71
84
|
make test-convergence
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: liger_kernel
|
|
3
|
-
Version: 0.5.
|
|
3
|
+
Version: 0.5.9
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -320,9 +320,11 @@ loss.backward()
|
|
|
320
320
|
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
321
321
|
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
322
322
|
| Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
323
|
+
| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
323
324
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
324
325
|
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
|
|
325
326
|
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
327
|
+
| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
326
328
|
|
|
327
329
|
|
|
328
330
|
## Low-level APIs
|
|
@@ -269,9 +269,11 @@ loss.backward()
|
|
|
269
269
|
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
270
270
|
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
271
271
|
| Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
272
|
+
| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
272
273
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
273
274
|
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
|
|
274
275
|
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
276
|
+
| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
275
277
|
|
|
276
278
|
|
|
277
279
|
## Low-level APIs
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
site_name: Liger-Kernel Docs
|
|
2
|
-
site_url:
|
|
3
|
-
site_author:
|
|
2
|
+
# site_url: ...
|
|
3
|
+
# site_author: LinkedIn
|
|
4
4
|
site_description: Efficient Triton Kernels for LLM Training
|
|
5
5
|
theme:
|
|
6
6
|
name: material
|
|
@@ -66,4 +66,4 @@ edit_uri: edit/main/docs/
|
|
|
66
66
|
extra:
|
|
67
67
|
social:
|
|
68
68
|
- icon: simple/github
|
|
69
|
-
link: https://github.com/linkedin/Liger-Kernel
|
|
69
|
+
link: https://github.com/linkedin/Liger-Kernel
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel"
|
|
7
|
-
version = "0.5.
|
|
7
|
+
version = "0.5.9"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -68,6 +68,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
68
68
|
compute_nll_loss=False,
|
|
69
69
|
compiled=True,
|
|
70
70
|
use_ref_model=True,
|
|
71
|
+
average_log_prob=False,
|
|
71
72
|
chunk_size=1,
|
|
72
73
|
):
|
|
73
74
|
"""
|
|
@@ -85,6 +86,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
85
86
|
compute_nll_loss (bool): Whether to compute the NLL loss
|
|
86
87
|
compiled (bool): Whether to use torch compile
|
|
87
88
|
use_ref_model (bool): Whether to use a reference model
|
|
89
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token
|
|
88
90
|
chunk_size (int): Size of chunks for processing.
|
|
89
91
|
Returns:
|
|
90
92
|
torch.Tensor: Computed loss
|
|
@@ -104,13 +106,14 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
104
106
|
ref_input=ref_input,
|
|
105
107
|
ref_weight=ref_weight,
|
|
106
108
|
ref_bias=ref_bias,
|
|
109
|
+
average_log_prob=average_log_prob,
|
|
107
110
|
chunk_size=chunk_size,
|
|
108
111
|
)
|
|
109
112
|
|
|
110
113
|
@staticmethod
|
|
111
114
|
def backward(ctx, *grad_output):
|
|
112
115
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
|
113
|
-
return *grads, None, None, None, None, None, None, None, None, None
|
|
116
|
+
return *grads, None, None, None, None, None, None, None, None, None, None
|
|
114
117
|
|
|
115
118
|
|
|
116
119
|
class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
@@ -125,6 +128,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
125
128
|
compute_nll_loss: bool = False,
|
|
126
129
|
compiled: bool = True,
|
|
127
130
|
use_ref_model: bool = True,
|
|
131
|
+
average_log_prob: bool = True,
|
|
128
132
|
chunk_size: int = 1,
|
|
129
133
|
):
|
|
130
134
|
"""
|
|
@@ -134,6 +138,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
134
138
|
compute_nll_loss (bool): Whether to compute the NLL loss.
|
|
135
139
|
compiled (bool): Whether to use the torch compiled kernel.
|
|
136
140
|
use_ref_model (bool): Whether to use a reference model for the DPO loss.
|
|
141
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token.
|
|
137
142
|
chunk_size (int): Size of chunks for processing.
|
|
138
143
|
"""
|
|
139
144
|
super().__init__()
|
|
@@ -142,6 +147,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
142
147
|
self.compute_nll_loss = compute_nll_loss
|
|
143
148
|
self.compiled = compiled
|
|
144
149
|
self.use_ref_model = use_ref_model
|
|
150
|
+
self.average_log_prob = average_log_prob
|
|
145
151
|
self.chunk_size = chunk_size
|
|
146
152
|
|
|
147
153
|
def forward(
|
|
@@ -167,5 +173,6 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
167
173
|
self.compute_nll_loss,
|
|
168
174
|
self.compiled,
|
|
169
175
|
self.use_ref_model,
|
|
176
|
+
self.average_log_prob,
|
|
170
177
|
self.chunk_size,
|
|
171
178
|
)
|
|
@@ -351,7 +351,10 @@ def cross_entropy_backward(_input, grad_output):
|
|
|
351
351
|
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
|
|
352
352
|
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
|
353
353
|
pass
|
|
354
|
-
|
|
354
|
+
# If reduction is 'none'
|
|
355
|
+
elif grad_output.ndim > 0:
|
|
356
|
+
_input = _input * grad_output.unsqueeze(dim=1)
|
|
357
|
+
# If reduction is ['mean', 'sum'], grad_output is just a scalar
|
|
355
358
|
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
|
|
356
359
|
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
|
|
357
360
|
else:
|
{liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/fused_linear_cross_entropy.py
RENAMED
|
@@ -143,9 +143,10 @@ def fused_linear_cross_entropy_forward(
|
|
|
143
143
|
alpha=1.0,
|
|
144
144
|
)
|
|
145
145
|
|
|
146
|
-
if reduction
|
|
147
|
-
|
|
148
|
-
|
|
146
|
+
# Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
|
|
147
|
+
# if reduction == "none":
|
|
148
|
+
# loss = loss_1d
|
|
149
|
+
# z_loss = z_loss_1d if return_z_loss else None
|
|
149
150
|
|
|
150
151
|
else:
|
|
151
152
|
loss = torch.sum(loss_1d)
|
|
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
|
|
|
26
26
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
|
|
27
27
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
|
|
28
28
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
|
|
29
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
|
|
29
30
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
|
|
30
31
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
|
|
31
32
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
|
|
@@ -38,6 +39,7 @@ if TYPE_CHECKING:
|
|
|
38
39
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
|
|
39
40
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401
|
|
40
41
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
|
|
42
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
|
|
41
43
|
|
|
42
44
|
|
|
43
45
|
# Check if 'transformers' is installed
|
|
@@ -79,6 +81,7 @@ def __getattr__(name: str):
|
|
|
79
81
|
"apply_liger_kernel_to_gemma2",
|
|
80
82
|
"apply_liger_kernel_to_gemma3",
|
|
81
83
|
"apply_liger_kernel_to_gemma3_text",
|
|
84
|
+
"apply_liger_kernel_to_glm4",
|
|
82
85
|
"apply_liger_kernel_to_granite",
|
|
83
86
|
"apply_liger_kernel_to_llama",
|
|
84
87
|
"apply_liger_kernel_to_llava",
|
|
@@ -91,6 +94,7 @@ def __getattr__(name: str):
|
|
|
91
94
|
"apply_liger_kernel_to_qwen2",
|
|
92
95
|
"apply_liger_kernel_to_qwen2_5_vl",
|
|
93
96
|
"apply_liger_kernel_to_qwen2_vl",
|
|
97
|
+
"apply_liger_kernel_to_qwen3",
|
|
94
98
|
}
|
|
95
99
|
|
|
96
100
|
if name in monkey_patch_symbols:
|
|
@@ -129,6 +133,7 @@ if _TRANSFORMERS_AVAILABLE:
|
|
|
129
133
|
"apply_liger_kernel_to_gemma2",
|
|
130
134
|
"apply_liger_kernel_to_gemma3",
|
|
131
135
|
"apply_liger_kernel_to_gemma3_text",
|
|
136
|
+
"apply_liger_kernel_to_glm4",
|
|
132
137
|
"apply_liger_kernel_to_granite",
|
|
133
138
|
"apply_liger_kernel_to_llama",
|
|
134
139
|
"apply_liger_kernel_to_llava",
|
|
@@ -141,5 +146,6 @@ if _TRANSFORMERS_AVAILABLE:
|
|
|
141
146
|
"apply_liger_kernel_to_qwen2",
|
|
142
147
|
"apply_liger_kernel_to_qwen2_5_vl",
|
|
143
148
|
"apply_liger_kernel_to_qwen2_vl",
|
|
149
|
+
"apply_liger_kernel_to_qwen3",
|
|
144
150
|
]
|
|
145
151
|
)
|
|
@@ -23,8 +23,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
|
23
23
|
assert reduction in {
|
|
24
24
|
"mean",
|
|
25
25
|
"sum",
|
|
26
|
-
|
|
27
|
-
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
|
|
26
|
+
}, f"reduction must be 'mean' or 'sum'. Got: {reduction}"
|
|
28
27
|
assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
|
|
29
28
|
self.ce_weight = ce_weight
|
|
30
29
|
self.ignore_index = ignore_index
|
|
@@ -200,21 +200,25 @@ def lce_forward(
|
|
|
200
200
|
)
|
|
201
201
|
|
|
202
202
|
hidden_states = outputs[0]
|
|
203
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
204
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
205
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
203
206
|
|
|
207
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
204
208
|
logits = None
|
|
205
209
|
loss = None
|
|
206
210
|
# if in training mode, don't materialize logits
|
|
207
|
-
if self.training and (labels is not None):
|
|
211
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
208
212
|
loss = LigerForCausalLMLoss(
|
|
209
|
-
hidden_states=
|
|
213
|
+
hidden_states=kept_hidden_states,
|
|
210
214
|
lm_head_weight=self.lm_head.weight,
|
|
211
215
|
labels=labels,
|
|
216
|
+
shift_labels=shift_labels,
|
|
212
217
|
hidden_size=self.config.hidden_size,
|
|
213
218
|
**loss_kwargs,
|
|
214
219
|
)
|
|
215
220
|
else: # if in inference mode materialize logits
|
|
216
|
-
|
|
217
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
221
|
+
logits = self.lm_head(kept_hidden_states)
|
|
218
222
|
if labels is not None:
|
|
219
223
|
loss = self.loss_function(
|
|
220
224
|
logits=logits,
|
|
@@ -212,23 +212,27 @@ def lce_forward(
|
|
|
212
212
|
)
|
|
213
213
|
|
|
214
214
|
hidden_states = outputs[0]
|
|
215
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
216
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
217
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
215
218
|
|
|
219
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
216
220
|
logits = None
|
|
217
221
|
loss = None
|
|
218
222
|
# if in training mode, don't materialize logits
|
|
219
|
-
if self.training and (labels is not None):
|
|
223
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
220
224
|
loss = LigerForCausalLMLoss(
|
|
221
|
-
hidden_states=
|
|
225
|
+
hidden_states=kept_hidden_states,
|
|
222
226
|
lm_head_weight=self.lm_head.weight,
|
|
223
227
|
labels=labels,
|
|
228
|
+
shift_labels=shift_labels,
|
|
224
229
|
hidden_size=self.config.hidden_size,
|
|
225
230
|
final_logit_softcapping=self.config.final_logit_softcapping,
|
|
226
231
|
**loss_kwargs,
|
|
227
232
|
)
|
|
228
233
|
|
|
229
234
|
else: # if in inference mode materialize logits
|
|
230
|
-
|
|
231
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
235
|
+
logits = self.lm_head(kept_hidden_states)
|
|
232
236
|
if self.config.final_logit_softcapping is not None:
|
|
233
237
|
logits = logits / self.config.final_logit_softcapping
|
|
234
238
|
logits = torch.tanh(logits)
|
|
@@ -104,13 +104,15 @@ def causal_forward(
|
|
|
104
104
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
105
105
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
106
106
|
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
107
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
107
108
|
loss = None
|
|
108
109
|
logits = None
|
|
109
|
-
if self.training and (labels is not None):
|
|
110
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
110
111
|
loss = LigerForCausalLMLoss(
|
|
111
112
|
hidden_states=kept_hidden_states,
|
|
112
113
|
lm_head_weight=self.lm_head.weight,
|
|
113
114
|
labels=labels,
|
|
115
|
+
shift_labels=shift_labels,
|
|
114
116
|
hidden_size=self.config.hidden_size,
|
|
115
117
|
final_logit_softcapping=self.config.final_logit_softcapping,
|
|
116
118
|
**loss_kwargs,
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
9
|
+
from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
|
|
10
|
+
from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
|
|
11
|
+
from transformers.utils import add_start_docstrings_to_model_forward
|
|
12
|
+
from transformers.utils import replace_return_docstrings
|
|
13
|
+
from transformers.utils.deprecation import deprecate_kwarg
|
|
14
|
+
|
|
15
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
19
|
+
@add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
|
|
20
|
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
21
|
+
def lce_forward(
|
|
22
|
+
self,
|
|
23
|
+
input_ids: torch.LongTensor = None,
|
|
24
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
25
|
+
position_ids: Optional[torch.LongTensor] = None,
|
|
26
|
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
27
|
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
28
|
+
labels: Optional[torch.LongTensor] = None,
|
|
29
|
+
use_cache: Optional[bool] = None,
|
|
30
|
+
output_attentions: Optional[bool] = None,
|
|
31
|
+
output_hidden_states: Optional[bool] = None,
|
|
32
|
+
return_dict: Optional[bool] = None,
|
|
33
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
34
|
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
35
|
+
**loss_kwargs,
|
|
36
|
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
37
|
+
r"""
|
|
38
|
+
Args:
|
|
39
|
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
40
|
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
41
|
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
42
|
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
43
|
+
|
|
44
|
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
45
|
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
46
|
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
47
|
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
48
|
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
49
|
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
|
|
53
|
+
Example:
|
|
54
|
+
|
|
55
|
+
```python
|
|
56
|
+
>>> from transformers import AutoTokenizer, Glm4ForCausalLM
|
|
57
|
+
|
|
58
|
+
>>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
|
|
59
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
|
|
60
|
+
|
|
61
|
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
62
|
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
63
|
+
|
|
64
|
+
>>> # Generate
|
|
65
|
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
66
|
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
67
|
+
'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
|
|
68
|
+
```
|
|
69
|
+
"""
|
|
70
|
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
71
|
+
output_hidden_states = (
|
|
72
|
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
73
|
+
)
|
|
74
|
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
75
|
+
|
|
76
|
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
77
|
+
outputs = self.model(
|
|
78
|
+
input_ids=input_ids,
|
|
79
|
+
attention_mask=attention_mask,
|
|
80
|
+
position_ids=position_ids,
|
|
81
|
+
past_key_values=past_key_values,
|
|
82
|
+
inputs_embeds=inputs_embeds,
|
|
83
|
+
use_cache=use_cache,
|
|
84
|
+
output_attentions=output_attentions,
|
|
85
|
+
output_hidden_states=output_hidden_states,
|
|
86
|
+
return_dict=return_dict,
|
|
87
|
+
cache_position=cache_position,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
hidden_states = outputs[0]
|
|
91
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
92
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
93
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
94
|
+
|
|
95
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
96
|
+
logits = None
|
|
97
|
+
loss = None
|
|
98
|
+
# if in training mode, don't materialize logits
|
|
99
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
100
|
+
loss = LigerForCausalLMLoss(
|
|
101
|
+
hidden_states=kept_hidden_states,
|
|
102
|
+
lm_head_weight=self.lm_head.weight,
|
|
103
|
+
labels=labels,
|
|
104
|
+
shift_labels=shift_labels,
|
|
105
|
+
hidden_size=self.config.hidden_size,
|
|
106
|
+
**loss_kwargs,
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
else: # if in inference mode materialize logits
|
|
110
|
+
logits = self.lm_head(kept_hidden_states)
|
|
111
|
+
if labels is not None:
|
|
112
|
+
loss = self.loss_function(
|
|
113
|
+
logits=logits,
|
|
114
|
+
labels=labels,
|
|
115
|
+
vocab_size=self.config.vocab_size,
|
|
116
|
+
**loss_kwargs,
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
return CausalLMOutputWithPast(
|
|
120
|
+
loss=loss,
|
|
121
|
+
logits=logits,
|
|
122
|
+
past_key_values=outputs.past_key_values,
|
|
123
|
+
hidden_states=outputs.hidden_states,
|
|
124
|
+
attentions=outputs.attentions,
|
|
125
|
+
)
|
|
@@ -209,25 +209,29 @@ def lce_forward(
|
|
|
209
209
|
)
|
|
210
210
|
|
|
211
211
|
hidden_states = outputs[0]
|
|
212
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
213
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
214
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
212
215
|
|
|
213
216
|
if self.config.pretraining_tp > 1:
|
|
214
217
|
raise Exception("Liger Kernel does not support pretraining_tp!!")
|
|
215
218
|
|
|
219
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
216
220
|
logits = None
|
|
217
221
|
loss = None
|
|
218
222
|
# if in training mode, don't materialize logits
|
|
219
|
-
if self.training and (labels is not None):
|
|
223
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
220
224
|
loss = LigerForCausalLMLoss(
|
|
221
|
-
hidden_states=
|
|
225
|
+
hidden_states=kept_hidden_states,
|
|
222
226
|
lm_head_weight=self.lm_head.weight,
|
|
223
227
|
labels=labels,
|
|
228
|
+
shift_labels=shift_labels,
|
|
224
229
|
hidden_size=self.config.hidden_size,
|
|
225
230
|
**loss_kwargs,
|
|
226
231
|
)
|
|
227
232
|
|
|
228
233
|
else: # if in inference mode materialize logits
|
|
229
|
-
|
|
230
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
234
|
+
logits = self.lm_head(kept_hidden_states)
|
|
231
235
|
if labels is not None:
|
|
232
236
|
loss = self.loss_function(
|
|
233
237
|
logits=logits,
|
|
@@ -91,22 +91,26 @@ def lce_forward(
|
|
|
91
91
|
)
|
|
92
92
|
|
|
93
93
|
hidden_states = outputs[0]
|
|
94
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
95
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
96
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
94
97
|
|
|
98
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
95
99
|
loss = None
|
|
96
100
|
logits = None
|
|
97
101
|
|
|
98
|
-
if self.training and (labels is not None):
|
|
102
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
99
103
|
loss = LigerForCausalLMLoss(
|
|
100
|
-
hidden_states=
|
|
104
|
+
hidden_states=kept_hidden_states,
|
|
101
105
|
lm_head_weight=self.lm_head.weight,
|
|
102
106
|
labels=labels,
|
|
107
|
+
shift_labels=shift_labels,
|
|
103
108
|
hidden_size=self.config.hidden_size,
|
|
104
109
|
**loss_kwargs,
|
|
105
110
|
)
|
|
106
111
|
|
|
107
112
|
else:
|
|
108
|
-
|
|
109
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
113
|
+
logits = self.lm_head(kept_hidden_states)
|
|
110
114
|
|
|
111
115
|
loss = None
|
|
112
116
|
if labels is not None:
|
|
@@ -225,22 +225,26 @@ def lce_forward(
|
|
|
225
225
|
)
|
|
226
226
|
|
|
227
227
|
hidden_states = outputs[0]
|
|
228
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
229
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
230
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
228
231
|
|
|
232
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
229
233
|
logits = None
|
|
230
234
|
loss = None
|
|
231
235
|
# if in training mode, don't materialize logits
|
|
232
|
-
if self.training and (labels is not None):
|
|
236
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
233
237
|
loss = LigerForCausalLMLoss(
|
|
234
|
-
hidden_states=
|
|
238
|
+
hidden_states=kept_hidden_states,
|
|
235
239
|
lm_head_weight=self.lm_head.weight,
|
|
236
240
|
labels=labels,
|
|
241
|
+
shift_labels=shift_labels,
|
|
237
242
|
hidden_size=self.config.hidden_size,
|
|
238
243
|
**loss_kwargs,
|
|
239
244
|
)
|
|
240
245
|
|
|
241
246
|
else: # if in inference mode materialize logits
|
|
242
|
-
|
|
243
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
247
|
+
logits = self.lm_head(kept_hidden_states)
|
|
244
248
|
|
|
245
249
|
loss = None
|
|
246
250
|
if labels is not None:
|
|
@@ -215,22 +215,26 @@ def lce_forward(
|
|
|
215
215
|
)
|
|
216
216
|
|
|
217
217
|
hidden_states = outputs[0]
|
|
218
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
219
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
220
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
218
221
|
|
|
222
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
219
223
|
logits = None
|
|
220
224
|
loss = None
|
|
221
225
|
# if in training mode, don't materialize logits
|
|
222
|
-
if self.training and (labels is not None):
|
|
226
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
223
227
|
loss = LigerForCausalLMLoss(
|
|
224
|
-
hidden_states=
|
|
228
|
+
hidden_states=kept_hidden_states,
|
|
225
229
|
lm_head_weight=self.lm_head.weight,
|
|
226
230
|
labels=labels,
|
|
231
|
+
shift_labels=shift_labels,
|
|
227
232
|
hidden_size=self.config.hidden_size,
|
|
228
233
|
**loss_kwargs,
|
|
229
234
|
)
|
|
230
235
|
|
|
231
236
|
else: # if in inference mode materialize logits
|
|
232
|
-
|
|
233
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
237
|
+
logits = self.lm_head(kept_hidden_states)
|
|
234
238
|
if labels is not None:
|
|
235
239
|
loss = self.loss_function(
|
|
236
240
|
logits=logits,
|
|
@@ -88,22 +88,26 @@ def lce_forward(
|
|
|
88
88
|
)
|
|
89
89
|
|
|
90
90
|
hidden_states = outputs[0]
|
|
91
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
92
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
93
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
91
94
|
|
|
95
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
92
96
|
logits = None
|
|
93
97
|
loss = None
|
|
94
98
|
# if in training mode, don't materialize logits
|
|
95
|
-
if self.training and (labels is not None):
|
|
99
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
96
100
|
loss = LigerForCausalLMLoss(
|
|
97
|
-
hidden_states=
|
|
101
|
+
hidden_states=kept_hidden_states,
|
|
98
102
|
lm_head_weight=self.lm_head.weight,
|
|
99
103
|
labels=labels,
|
|
104
|
+
shift_labels=shift_labels,
|
|
100
105
|
hidden_size=self.config.hidden_size,
|
|
101
106
|
**loss_kwargs,
|
|
102
107
|
)
|
|
103
108
|
|
|
104
109
|
else: # if in inference mode materialize logits
|
|
105
|
-
|
|
106
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
110
|
+
logits = self.lm_head(kept_hidden_states)
|
|
107
111
|
if labels is not None:
|
|
108
112
|
loss = self.loss_function(
|
|
109
113
|
logits=logits,
|