liger-kernel-nightly 0.5.5.dev20250315175408__tar.gz → 0.5.5.dev20250317215555__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/gemma.py +8 -16
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/gemma2.py +7 -16
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/llama.py +8 -15
- liger_kernel_nightly-0.5.5.dev20250317215555/src/liger_kernel/transformers/model/loss_utils.py +57 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/mistral.py +9 -10
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/mixtral.py +8 -15
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/mllama.py +8 -15
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/olmo2.py +8 -16
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/paligemma.py +1 -1
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/phi3.py +8 -15
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/qwen2.py +8 -15
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/qwen2_vl.py +9 -10
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/convergence/bf16/test_mini_models_multimodal.py +0 -1
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/convergence/fp32/test_mini_models_multimodal.py +0 -1
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/Makefile +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/setup.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/test/utils.py +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel_nightly"
|
|
7
|
-
version = "0.5.5.dev20250315175408"
|
|
7
|
+
version = "0.5.5.dev20250317215555"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -14,6 +14,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
|
|
|
14
14
|
from transformers.utils import replace_return_docstrings
|
|
15
15
|
|
|
16
16
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
17
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
|
@@ -200,22 +201,13 @@ def lce_forward(
|
|
|
200
201
|
loss = None
|
|
201
202
|
# if in training mode, don't materialize logits
|
|
202
203
|
if self.training and (labels is not None):
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
shift_labels = shift_labels.view(-1)
|
|
211
|
-
|
|
212
|
-
reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
|
|
213
|
-
lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
|
|
214
|
-
|
|
215
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
216
|
-
if reduction == "sum":
|
|
217
|
-
loss /= loss_kwargs["num_items_in_batch"]
|
|
218
|
-
|
|
204
|
+
loss = LigerForCausalLMLoss(
|
|
205
|
+
hidden_states=hidden_states,
|
|
206
|
+
lm_head_weight=self.lm_head.weight,
|
|
207
|
+
labels=labels,
|
|
208
|
+
hidden_size=self.config.hidden_size,
|
|
209
|
+
**loss_kwargs,
|
|
210
|
+
)
|
|
219
211
|
else: # if in inference mode materialize logits
|
|
220
212
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
|
221
213
|
if labels is not None:
|
|
@@ -15,6 +15,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
|
|
|
15
15
|
from transformers.utils import replace_return_docstrings
|
|
16
16
|
|
|
17
17
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
18
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
19
|
|
|
19
20
|
logger = logging.getLogger(__name__)
|
|
20
21
|
|
|
@@ -212,25 +213,15 @@ def lce_forward(
|
|
|
212
213
|
loss = None
|
|
213
214
|
# if in training mode, don't materialize logits
|
|
214
215
|
if self.training and (labels is not None):
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
# flatten tokens
|
|
221
|
-
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
|
|
222
|
-
shift_labels = shift_labels.view(-1)
|
|
223
|
-
|
|
224
|
-
reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
|
|
225
|
-
lce = LigerFusedLinearCrossEntropyLoss(
|
|
216
|
+
loss = LigerForCausalLMLoss(
|
|
217
|
+
hidden_states=hidden_states,
|
|
218
|
+
lm_head_weight=self.lm_head.weight,
|
|
219
|
+
labels=labels,
|
|
220
|
+
hidden_size=self.config.hidden_size,
|
|
226
221
|
softcap=self.config.final_logit_softcapping,
|
|
227
|
-
|
|
222
|
+
**loss_kwargs,
|
|
228
223
|
)
|
|
229
224
|
|
|
230
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
231
|
-
if reduction == "sum":
|
|
232
|
-
loss /= loss_kwargs["num_items_in_batch"]
|
|
233
|
-
|
|
234
225
|
else: # if in inference mode materialize logits
|
|
235
226
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
|
236
227
|
if self.config.final_logit_softcapping is not None:
|
|
@@ -15,6 +15,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
|
|
|
15
15
|
from transformers.utils import replace_return_docstrings
|
|
16
16
|
|
|
17
17
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
18
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
19
|
|
|
19
20
|
if TYPE_CHECKING:
|
|
20
21
|
from transformers.cache_utils import Cache
|
|
@@ -212,21 +213,13 @@ def lce_forward(
|
|
|
212
213
|
loss = None
|
|
213
214
|
# if in training mode, don't materialize logits
|
|
214
215
|
if self.training and (labels is not None):
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
shift_labels = shift_labels.view(-1)
|
|
223
|
-
|
|
224
|
-
reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
|
|
225
|
-
lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
|
|
226
|
-
|
|
227
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
228
|
-
if reduction == "sum":
|
|
229
|
-
loss /= loss_kwargs["num_items_in_batch"]
|
|
216
|
+
loss = LigerForCausalLMLoss(
|
|
217
|
+
hidden_states=hidden_states,
|
|
218
|
+
lm_head_weight=self.lm_head.weight,
|
|
219
|
+
labels=labels,
|
|
220
|
+
hidden_size=self.config.hidden_size,
|
|
221
|
+
**loss_kwargs,
|
|
222
|
+
)
|
|
230
223
|
|
|
231
224
|
else: # if in inference mode materialize logits
|
|
232
225
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
liger_kernel_nightly-0.5.5.dev20250317215555/src/liger_kernel/transformers/model/loss_utils.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
|
|
3
|
+
import liger_kernel.transformers.functional as F
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def fixed_fused_linear_cross_entropy(
|
|
7
|
+
hidden_states,
|
|
8
|
+
lm_head_weight,
|
|
9
|
+
target,
|
|
10
|
+
num_items_in_batch: int = None,
|
|
11
|
+
ignore_index: int = -100,
|
|
12
|
+
**kwargs,
|
|
13
|
+
):
|
|
14
|
+
reduction = "sum" if num_items_in_batch is not None else "mean"
|
|
15
|
+
loss = F.liger_fused_linear_cross_entropy(
|
|
16
|
+
hidden_states,
|
|
17
|
+
lm_head_weight,
|
|
18
|
+
target,
|
|
19
|
+
reduction=reduction,
|
|
20
|
+
ignore_index=ignore_index,
|
|
21
|
+
**kwargs,
|
|
22
|
+
)
|
|
23
|
+
if reduction == "sum":
|
|
24
|
+
loss = loss / num_items_in_batch
|
|
25
|
+
|
|
26
|
+
return loss
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def LigerForCausalLMLoss(
|
|
30
|
+
hidden_states,
|
|
31
|
+
lm_head_weight,
|
|
32
|
+
labels,
|
|
33
|
+
hidden_size: int,
|
|
34
|
+
num_items_in_batch: int = None,
|
|
35
|
+
ignore_index: int = -100,
|
|
36
|
+
**kwargs,
|
|
37
|
+
):
|
|
38
|
+
# Skip upcast since intermediate values for the loss are all fp32 in kernel
|
|
39
|
+
labels = labels.to(hidden_states.device)
|
|
40
|
+
# Shift so that token < n predict n
|
|
41
|
+
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
|
|
42
|
+
shift_labels = labels[..., 1:].contiguous()
|
|
43
|
+
|
|
44
|
+
# Flatten the tokens
|
|
45
|
+
hidden_states = hidden_states.view(-1, hidden_size)
|
|
46
|
+
shift_labels = shift_labels.view(-1)
|
|
47
|
+
# Enable model parallelism
|
|
48
|
+
shift_labels = shift_labels.to(hidden_states.device)
|
|
49
|
+
loss = fixed_fused_linear_cross_entropy(
|
|
50
|
+
hidden_states,
|
|
51
|
+
lm_head_weight,
|
|
52
|
+
shift_labels,
|
|
53
|
+
num_items_in_batch,
|
|
54
|
+
ignore_index,
|
|
55
|
+
**kwargs,
|
|
56
|
+
)
|
|
57
|
+
return loss
|
|
@@ -13,7 +13,7 @@ from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRIN
|
|
|
13
13
|
from transformers.utils import add_start_docstrings_to_model_forward
|
|
14
14
|
from transformers.utils import replace_return_docstrings
|
|
15
15
|
|
|
16
|
-
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
16
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
|
@@ -31,6 +31,7 @@ def lce_forward(
|
|
|
31
31
|
output_hidden_states: Optional[bool] = None,
|
|
32
32
|
return_dict: Optional[bool] = None,
|
|
33
33
|
cache_position: Optional[torch.LongTensor] = None,
|
|
34
|
+
**loss_kwargs,
|
|
34
35
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
35
36
|
r"""
|
|
36
37
|
Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
|
|
@@ -87,15 +88,13 @@ def lce_forward(
|
|
|
87
88
|
logits = None
|
|
88
89
|
|
|
89
90
|
if self.training and (labels is not None):
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
lce = LigerFusedLinearCrossEntropyLoss()
|
|
98
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
91
|
+
loss = LigerForCausalLMLoss(
|
|
92
|
+
hidden_states=hidden_states,
|
|
93
|
+
lm_head_weight=self.lm_head.weight,
|
|
94
|
+
labels=labels,
|
|
95
|
+
hidden_size=self.config.hidden_size,
|
|
96
|
+
**loss_kwargs,
|
|
97
|
+
)
|
|
99
98
|
|
|
100
99
|
else:
|
|
101
100
|
logits = self.lm_head(hidden_states)
|
|
@@ -14,6 +14,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
|
|
|
14
14
|
from transformers.utils import replace_return_docstrings
|
|
15
15
|
|
|
16
16
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
17
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
|
|
@@ -225,21 +226,13 @@ def lce_forward(
|
|
|
225
226
|
loss = None
|
|
226
227
|
# if in training mode, don't materialize logits
|
|
227
228
|
if self.training and (labels is not None):
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
shift_labels = shift_labels.view(-1)
|
|
236
|
-
|
|
237
|
-
reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
|
|
238
|
-
lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
|
|
239
|
-
|
|
240
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
241
|
-
if reduction == "sum":
|
|
242
|
-
loss /= loss_kwargs["num_items_in_batch"]
|
|
229
|
+
loss = LigerForCausalLMLoss(
|
|
230
|
+
hidden_states=hidden_states,
|
|
231
|
+
lm_head_weight=self.lm_head.weight,
|
|
232
|
+
labels=labels,
|
|
233
|
+
hidden_size=self.config.hidden_size,
|
|
234
|
+
**loss_kwargs,
|
|
235
|
+
)
|
|
243
236
|
|
|
244
237
|
else: # if in inference mode materialize logits
|
|
245
238
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
|
@@ -13,6 +13,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
|
|
|
13
13
|
from transformers.utils import replace_return_docstrings
|
|
14
14
|
|
|
15
15
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
16
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
|
@@ -215,21 +216,13 @@ def lce_forward(
|
|
|
215
216
|
loss = None
|
|
216
217
|
# if in training mode, don't materialize logits
|
|
217
218
|
if self.training and (labels is not None):
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
shift_labels = shift_labels.view(-1)
|
|
226
|
-
|
|
227
|
-
reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
|
|
228
|
-
lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
|
|
229
|
-
|
|
230
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
231
|
-
if reduction == "sum":
|
|
232
|
-
loss /= loss_kwargs["num_items_in_batch"]
|
|
219
|
+
loss = LigerForCausalLMLoss(
|
|
220
|
+
hidden_states=hidden_states,
|
|
221
|
+
lm_head_weight=self.lm_head.weight,
|
|
222
|
+
labels=labels,
|
|
223
|
+
hidden_size=self.config.hidden_size,
|
|
224
|
+
**loss_kwargs,
|
|
225
|
+
)
|
|
233
226
|
|
|
234
227
|
else: # if in inference mode materialize logits
|
|
235
228
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
|
@@ -11,7 +11,7 @@ from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
|
|
|
11
11
|
from transformers.utils import add_start_docstrings_to_model_forward
|
|
12
12
|
from transformers.utils import replace_return_docstrings
|
|
13
13
|
|
|
14
|
-
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
14
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
|
|
@@ -89,21 +89,13 @@ def lce_forward(
|
|
|
89
89
|
loss = None
|
|
90
90
|
# if in training mode, don't materialize logits
|
|
91
91
|
if self.training and (labels is not None):
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
shift_labels = shift_labels.view(-1)
|
|
100
|
-
|
|
101
|
-
reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
|
|
102
|
-
lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
|
|
103
|
-
|
|
104
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
105
|
-
if reduction == "sum":
|
|
106
|
-
loss /= loss_kwargs["num_items_in_batch"]
|
|
92
|
+
loss = LigerForCausalLMLoss(
|
|
93
|
+
hidden_states=hidden_states,
|
|
94
|
+
lm_head_weight=self.lm_head.weight,
|
|
95
|
+
labels=labels,
|
|
96
|
+
hidden_size=self.config.hidden_size,
|
|
97
|
+
**loss_kwargs,
|
|
98
|
+
)
|
|
107
99
|
|
|
108
100
|
else: # if in inference mode materialize logits
|
|
109
101
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
|
@@ -172,7 +172,7 @@ def lce_forward(
|
|
|
172
172
|
shift_labels = shift_labels.contiguous()
|
|
173
173
|
|
|
174
174
|
# Flatten hidden state
|
|
175
|
-
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
|
|
175
|
+
shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
|
|
176
176
|
shift_labels = shift_labels.view(-1).to(hidden_device)
|
|
177
177
|
|
|
178
178
|
lce = LigerFusedLinearCrossEntropyLoss()
|
|
@@ -13,6 +13,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
|
|
|
13
13
|
from transformers.utils import replace_return_docstrings
|
|
14
14
|
|
|
15
15
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
16
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
|
|
@@ -213,21 +214,13 @@ def lce_forward(
|
|
|
213
214
|
loss = None
|
|
214
215
|
# if in training mode, don't materialize logits
|
|
215
216
|
if self.training and (labels is not None):
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
shift_labels = shift_labels.view(-1)
|
|
224
|
-
|
|
225
|
-
reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
|
|
226
|
-
lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
|
|
227
|
-
|
|
228
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
229
|
-
if reduction == "sum":
|
|
230
|
-
loss /= loss_kwargs["num_items_in_batch"]
|
|
217
|
+
loss = LigerForCausalLMLoss(
|
|
218
|
+
hidden_states=hidden_states,
|
|
219
|
+
lm_head_weight=self.lm_head.weight,
|
|
220
|
+
labels=labels,
|
|
221
|
+
hidden_size=self.config.hidden_size,
|
|
222
|
+
**loss_kwargs,
|
|
223
|
+
)
|
|
231
224
|
|
|
232
225
|
else: # if in inference mode materialize logits
|
|
233
226
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
|
@@ -13,6 +13,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
|
|
|
13
13
|
from transformers.utils import replace_return_docstrings
|
|
14
14
|
|
|
15
15
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
16
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
|
@@ -199,21 +200,13 @@ def lce_forward(
|
|
|
199
200
|
loss = None
|
|
200
201
|
# if in training mode, don't materialize logits
|
|
201
202
|
if self.training and (labels is not None):
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
shift_labels = shift_labels.view(-1)
|
|
210
|
-
|
|
211
|
-
reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
|
|
212
|
-
lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
|
|
213
|
-
|
|
214
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
215
|
-
if reduction == "sum":
|
|
216
|
-
loss /= loss_kwargs["num_items_in_batch"]
|
|
203
|
+
loss = LigerForCausalLMLoss(
|
|
204
|
+
hidden_states=hidden_states,
|
|
205
|
+
lm_head_weight=self.lm_head.weight,
|
|
206
|
+
labels=labels,
|
|
207
|
+
hidden_size=self.config.hidden_size,
|
|
208
|
+
**loss_kwargs,
|
|
209
|
+
)
|
|
217
210
|
|
|
218
211
|
else: # if in inference mode materialize logits
|
|
219
212
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
|
@@ -12,7 +12,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalL
|
|
|
12
12
|
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
13
|
from transformers.utils import replace_return_docstrings
|
|
14
14
|
|
|
15
|
-
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
15
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
|
|
@@ -36,6 +36,7 @@ def lce_forward(
|
|
|
36
36
|
rope_deltas: Optional[torch.LongTensor] = None,
|
|
37
37
|
cache_position: Optional[torch.LongTensor] = None,
|
|
38
38
|
second_per_grid_ts: Optional[torch.Tensor] = None,
|
|
39
|
+
**loss_kwargs,
|
|
39
40
|
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
|
40
41
|
r"""
|
|
41
42
|
Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
|
|
@@ -166,15 +167,13 @@ def lce_forward(
|
|
|
166
167
|
logits = None
|
|
167
168
|
|
|
168
169
|
if self.training and (labels is not None):
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
lce = LigerFusedLinearCrossEntropyLoss()
|
|
177
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
170
|
+
loss = LigerForCausalLMLoss(
|
|
171
|
+
hidden_states=hidden_states,
|
|
172
|
+
lm_head_weight=self.lm_head.weight,
|
|
173
|
+
labels=labels,
|
|
174
|
+
hidden_size=self.config.hidden_size,
|
|
175
|
+
**loss_kwargs,
|
|
176
|
+
)
|
|
178
177
|
else:
|
|
179
178
|
logits = self.lm_head(hidden_states)
|
|
180
179
|
if labels is not None:
|
|
@@ -14,7 +14,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutput
|
|
|
14
14
|
from transformers.utils import add_start_docstrings_to_model_forward
|
|
15
15
|
from transformers.utils import replace_return_docstrings
|
|
16
16
|
|
|
17
|
-
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
17
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
|
|
@@ -37,6 +37,7 @@ def lce_forward(
|
|
|
37
37
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
|
38
38
|
rope_deltas: Optional[torch.LongTensor] = None,
|
|
39
39
|
cache_position: Optional[torch.LongTensor] = None,
|
|
40
|
+
**loss_kwargs,
|
|
40
41
|
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
|
41
42
|
r"""
|
|
42
43
|
Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
|
|
@@ -170,15 +171,13 @@ def lce_forward(
|
|
|
170
171
|
logits = None
|
|
171
172
|
|
|
172
173
|
if self.training and (labels is not None):
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
lce = LigerFusedLinearCrossEntropyLoss()
|
|
181
|
-
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
|
174
|
+
loss = LigerForCausalLMLoss(
|
|
175
|
+
hidden_states=hidden_states,
|
|
176
|
+
lm_head_weight=self.lm_head.weight,
|
|
177
|
+
labels=labels,
|
|
178
|
+
hidden_size=self.config.hidden_size,
|
|
179
|
+
**loss_kwargs,
|
|
180
|
+
)
|
|
182
181
|
else:
|
|
183
182
|
logits = self.lm_head(hidden_states)
|
|
184
183
|
if labels is not None:
|
|
@@ -159,6 +159,7 @@ src/liger_kernel/transformers/model/__init__.py
|
|
|
159
159
|
src/liger_kernel/transformers/model/gemma.py
|
|
160
160
|
src/liger_kernel/transformers/model/gemma2.py
|
|
161
161
|
src/liger_kernel/transformers/model/llama.py
|
|
162
|
+
src/liger_kernel/transformers/model/loss_utils.py
|
|
162
163
|
src/liger_kernel/transformers/model/mistral.py
|
|
163
164
|
src/liger_kernel/transformers/model/mixtral.py
|
|
164
165
|
src/liger_kernel/transformers/model/mllama.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{liger_kernel_nightly-0.5.5.dev20250315175408 → liger_kernel_nightly-0.5.5.dev20250317215555}/NOTICE
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|