liger-kernel-nightly 0.5.5.dev20250328142430__tar.gz → 0.5.5.dev20250331042257__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/__init__.py +1 -0
- liger_kernel_nightly-0.5.5.dev20250331042257/src/liger_kernel/transformers/model/llava.py +383 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/monkey_patch.py +84 -1
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/convergence/bf16/test_mini_models.py +93 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/convergence/bf16/test_mini_models_multimodal.py +137 -2
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/convergence/bf16/test_mini_models_with_logits.py +94 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/convergence/fp32/test_mini_models.py +90 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/convergence/fp32/test_mini_models_multimodal.py +133 -1
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/convergence/fp32/test_mini_models_with_logits.py +92 -0
- liger_kernel_nightly-0.5.5.dev20250331042257/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +28 -0
- liger_kernel_nightly-0.5.5.dev20250331042257/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +7 -0
- liger_kernel_nightly-0.5.5.dev20250331042257/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +66 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/utils.py +31 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/Makefile +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/setup.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/kl_div.py +2 -2
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/paligemma.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_dyt.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.5.dev20250328142430 → liger_kernel_nightly-0.5.5.dev20250331042257}/test/triton/test_triton_monkey_patch.py +0 -0
pyproject.toml

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.5.5.dev20250328142430"
+version = "0.5.5.dev20250331042257"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
```
src/liger_kernel/transformers/__init__.py

```diff
@@ -12,6 +12,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
```
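With this export in place, the LLaVA patch is invoked like the other `apply_liger_kernel_to_*` entry points. A minimal usage sketch (the checkpoint id comes from the docstrings in this release; a CUDA device is assumed, since Liger's Triton kernels run on GPU):

```python
from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

# Load first, then pass the instance: per the docstring of
# apply_liger_kernel_to_llava, the model must be handed over so the
# patch can also reach the language model wired into LLaVA.
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
apply_liger_kernel_to_llava(fused_linear_cross_entropy=True, model=model)
```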
src/liger_kernel/transformers/model/llava.py (new file, 383 lines)

````python
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
from transformers.models.llava.modeling_llava import logger
from transformers.utils import add_start_docstrings_to_model_forward
from transformers.utils import is_torchdynamo_compiling
from transformers.utils import replace_return_docstrings
from transformers.utils.deprecation import deprecate_kwarg

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[int] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.


    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

    >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    legacy_processing = False
    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

        # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
        # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
        # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
        legacy_processing = (
            (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
        ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

    image_features = None
    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
        )

    if legacy_processing and image_features is not None:
        logger.warning_once(
            "Expanding inputs for image tokens in LLaVa should be done in processing. "
            "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
            "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
            "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
        )
        # prefill stage vs decoding stage (legacy behavior copied)
        if input_ids.shape[1] != 1:
            inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds, input_ids, attention_mask, labels
            )
            cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        else:
            # Retrieve the first layer to inspect the logits and mask out the hidden states
            # that are set to 0
            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
            batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

            # Get the target length
            target_length = input_ids.shape[1]
            past_length = first_layer_past_key_value.shape[-1]

            extended_attention_mask = torch.ones(
                (attention_mask.shape[0], past_length),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )

            # Filter out only the tokens that can be un-attended, this can happen
            # if one uses Llava + Fused modules where the cache on the
            # first iteration is already big enough, or if one passes custom cache
            valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
            new_batch_index = batch_index[valid_indices]
            new_non_attended_tokens = non_attended_tokens[valid_indices]

            # Zero-out the places where we don't need to attend
            extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

            attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
            position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
            cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

    # TODO: @raushan retain only the new behavior after v4.47
    elif image_features is not None:
        n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
        n_image_features = image_features.shape[0] * image_features.shape[1]

        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        special_image_mask = (
            (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

    outputs = self.language_model.model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        num_logits_to_keep=num_logits_to_keep,
    )
    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Shift so that tokens < n predict n
        if attention_mask is not None:
            # we use the input attention mask to shift the logits and labels, because it is 2D.
            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
            shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device)
            shift_hidden_states = hidden_states[..., :-1, :][
                shift_attention_mask.to(hidden_states.device) != 0
            ].contiguous()
            shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
        else:
            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)

    if not return_dict:
        # NOTE: This part has not been tested.
        output = outputs[1:]
        return (loss,) + output if loss is not None else output

    return LlavaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )


@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[int] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    image_sizes: torch.Tensor = None,
    **lm_kwargs,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).


    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

    >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            image_sizes=image_sizes,
        )

        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
            n_image_tokens = (input_ids == self.config.image_token_index).sum()
            n_image_features = image_features.shape[0] * image_features.shape[1]
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

    outputs = self.language_model.model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        **lm_kwargs,
    )
    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Shift so that tokens < n predict n
        if attention_mask is not None:
            # we use the input attention mask to shift the logits and labels, because it is 2D.
            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
            shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device)
            shift_hidden_states = hidden_states[..., :-1, :][
                shift_attention_mask.to(hidden_states.device) != 0
            ].contiguous()
            shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
        else:
            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(
            self.language_model.lm_head.weight,
            shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
            shift_labels.view(-1).to(shift_hidden_states.device),
        )

    if not return_dict:
        # NOTE: This part has not been tested.
        output = outputs[1:]
        return (loss,) + output if loss is not None else output

    return LlavaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )
````
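The training-time loss path is the point of this new file: instead of projecting hidden states through `lm_head` into a `(num_tokens, vocab_size)` logits tensor and then applying cross entropy, both forwards hand the `lm_head` weight and the shifted hidden states directly to `LigerFusedLinearCrossEntropyLoss`, which computes the projection in chunks so the full logits are never materialized (hence `logits=None` in the returned output during training). A minimal equivalence sketch, with illustrative shapes and loose tolerances, assuming a CUDA device since the fused loss is a Triton kernel:

```python
import torch
import torch.nn.functional as F

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

N, H, V = 8, 64, 32000                        # tokens, hidden size, vocab size
W = torch.randn(V, H, device="cuda")          # lm_head weight
h = torch.randn(N, H, device="cuda")          # hidden states, already shifted by one
y = torch.randint(0, V, (N,), device="cuda")  # labels, already shifted by one

reference = F.cross_entropy(h @ W.T, y)              # materializes an (N, V) logits tensor
fused = LigerFusedLinearCrossEntropyLoss()(W, h, y)  # same loss, full logits never stored
torch.testing.assert_close(fused, reference, rtol=1e-3, atol=1e-3)
```

The "shift by one" mirrors the forwards above: the hidden state at position t predicts the label at position t + 1, which is why `hidden_states[..., :-1, :]` is paired with `labels[..., 1:]` before the fused loss is applied.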
src/liger_kernel/transformers/monkey_patch.py

```diff
@@ -19,6 +19,8 @@ from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
+from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
+from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
```
@@ -57,7 +59,8 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
 
 def _patch_layer_norm_module(module, eps=1e-6):
     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    module.hidden_size = module
+    module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
     module.__class__.__name__ = LigerLayerNorm.__name__
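This one-line fix matters because torch.nn.LayerNorm exposes normalized_shape rather than hidden_size, so the old assignment never produced a usable size. A quick sketch of the attribute lookup the fixed fallback chain relies on (plain PyTorch, nothing Liger-specific):

import torch

ln = torch.nn.LayerNorm(64)
# torch's LayerNorm has normalized_shape, not hidden_size:
print(getattr(ln, "hidden_size", None))       # None
print(getattr(ln, "normalized_shape", None))  # (64,)

# The patched helper's fallback therefore resolves to (64,):
hidden_size = getattr(ln, "hidden_size", None) or getattr(ln, "normalized_shape", None)
print(hidden_size)  # (64,)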
@@ -224,6 +227,85 @@ def apply_liger_kernel_to_llama(
         _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_llava(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    model: PreTrainedModel = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace Llava models.
+    Because LLaVA wraps a language model and a vision tower, the model instance must be passed in
+    so that Liger-Kernel's patches can also be applied to the models connected to it.
+    However, if an LM not supported by Liger-Kernel is connected to LLaVA, unexpected side effects may occur.
+    NOTE: Llava is not available in transformers<4.36.0
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory-efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llava import modeling_llava
+
+    if cross_entropy:
+        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+        modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse("4.49.0"):
+            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        else:  # if version < 4.49.0
+            logger.warning(
+                "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing.\nPlease consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+            )
+            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+
+    if model is not None:
+        text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name} and will be ignored: {list(remain_params)}. Applying the remaining parameters: {list(text_kwargs.keys())}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = model.language_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        if vision_liger_fn:
+            accept_params = inspect.signature(vision_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {vision_model_name} and will be ignored: {list(remain_params)}. Applying the remaining parameters: {list(vision_kwargs.keys())}\n"
+                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+                )
+            vision_kwargs["model"] = model.vision_tower
+            vision_liger_fn(**vision_kwargs)
+        elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
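Since apply_liger_kernel_to_llava sits alongside the other apply_liger_kernel_to_* entry points (and the one-line __init__.py change in this release appears to export it), usage should look like the sketch below. The checkpoint id, dtype, and device are illustrative assumptions, not something this diff pins down:

import torch
from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # assumed checkpoint; any Llava model should do
    torch_dtype=torch.bfloat16,
)

# Passing the loaded instance lets the function look up the attached text and
# vision towers by model_type and patch them too, per the docstring above.
apply_liger_kernel_to_llava(model=model)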
@@ -1071,6 +1153,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
+    "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
     "mllama_text_model": apply_liger_kernel_to_mllama,
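This registry entry closes the loop: when apply_liger_kernel_to_llava receives a model instance, it resolves the attached text and vision towers through the same table. A minimal standalone sketch of that dispatch, assuming `model` is an already-loaded LlavaForConditionalGeneration (names mirror the hunks above, not a public API):

# Hypothetical version of the lookup performed inside monkey_patch.
text_model_name = model.config.text_config.model_type  # e.g. "llama" for llava-1.5
patch_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name)
if patch_fn is not None:
    # Liger CE/FLCE stay disabled on the sub-model; LLaVA's own loss is used.
    patch_fn(model=model.language_model, cross_entropy=False, fused_linear_cross_entropy=False)
else:
    print(f"{text_model_name} is not supported by Liger kernel.")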
@@ -162,6 +162,7 @@ src/liger_kernel/transformers/model/__init__.py
 src/liger_kernel/transformers/model/gemma.py
 src/liger_kernel/transformers/model/gemma2.py
 src/liger_kernel/transformers/model/llama.py
+src/liger_kernel/transformers/model/llava.py
 src/liger_kernel/transformers/model/loss_utils.py
 src/liger_kernel/transformers/model/mistral.py
 src/liger_kernel/transformers/model/mixtral.py
@@ -203,6 +204,9 @@ test/convergence/fp32/test_mini_models_multimodal.py
 test/convergence/fp32/test_mini_models_with_logits.py
 test/resources/tiny_shakespeare.txt
 test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json
+test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json
+test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json
+test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json
 test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json
 test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json
 test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json