liger-kernel-nightly 0.5.3.dev20250221162633__tar.gz → 0.5.3.dev20250221230243__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic.
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_tvd.py +8 -11
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/dev/modal/tests.py +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/training.py +2 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/training_multimodal.py +67 -23
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/tvd.py +6 -7
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/functional.py +4 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/tvd.py +1 -3
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/utils.py +2 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/bf16/test_mini_models.py +52 -37
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/bf16/test_mini_models_with_logits.py +51 -37
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/fp32/test_mini_models.py +61 -37
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/fp32/test_mini_models_with_logits.py +60 -37
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_flex_attention.py +25 -17
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_tvd.py +13 -20
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/Makefile +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/setup.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/__init__.py +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/utils.py +0 -0
benchmark/scripts/benchmark_tvd.py
@@ -1,13 +1,12 @@
 import torch
 import triton
-from utils import (
-    QUANTILES,
-    SingleBenchmarkRunInput,
-    SingleBenchmarkRunOutput,
-    _test_memory,
-    parse_benchmark_script_args,
-    run_benchmarks,
-)
+
+from utils import QUANTILES
+from utils import SingleBenchmarkRunInput
+from utils import SingleBenchmarkRunOutput
+from utils import _test_memory
+from utils import parse_benchmark_script_args
+from utils import run_benchmarks
 
 from liger_kernel.transformers.tvd import LigerTVDLoss
 
@@ -67,9 +66,7 @@ def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     y = fwd()
     y.backward(retain_graph=True)
 
-    ms_50, ms_20, ms_80 = triton.testing.do_bench(
-        full, quantiles=QUANTILES, rep=100
-    )
+    ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
     return SingleBenchmarkRunOutput(
         y_20=ms_20,
         y_50=ms_50,
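The reformatted `do_bench` call above collapses the timing quantiles onto one line. For readers unfamiliar with the API, a minimal hedged sketch of how the unpacking works; `QUANTILES` is assumed to be `[0.5, 0.2, 0.8]` (to match the `ms_50, ms_20, ms_80` names) and a CUDA device is required:

```python
# Sketch only: the benchmarked closure is a stand-in for the script's real
# forward/backward closure, and QUANTILES is an assumed value.
import torch
import triton

def full():
    a = torch.randn(1024, 1024, device="cuda")
    return a @ a

# With quantiles set, do_bench returns one timing (in ms) per requested quantile.
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=[0.5, 0.2, 0.8], rep=100)
```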
dev/modal/tests.py
@@ -14,7 +14,7 @@ app = modal.App("liger_tests", image=image)
 repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
 
 
-@app.function(gpu="A10G", mounts=[repo], timeout=60 *
+@app.function(gpu="A10G", mounts=[repo], timeout=60 * 20)
 def liger_tests():
     import subprocess
 
examples/huggingface/training.py
@@ -15,6 +15,7 @@ from liger_kernel.transformers import AutoLigerKernelForCausalLM
 class CustomArguments:
     model_name: str = "meta-llama/Meta-Llama-3-8B"
     dataset: str = "tatsu-lab/alpaca"
+    max_seq_length: int = 512
     use_liger: bool = False
 
 
@@ -65,6 +66,7 @@ def train():
         model=model,
         args=training_args,
         data_collator=collator,
+        max_seq_length=custom_args.max_seq_length,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         formatting_func=formatting_prompts_func,
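Together, the two `training.py` hunks thread a configurable sequence length from the command line into the trainer: `HfArgumentParser` exposes each dataclass field as a CLI flag, and the parsed value is forwarded to `SFTTrainer` as `max_seq_length`. A minimal sketch (flag values are illustrative):

```python
from dataclasses import dataclass

import transformers

@dataclass
class CustomArguments:
    max_seq_length: int = 512
    use_liger: bool = False

parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments))
# Equivalent to: python training.py --output_dir ./out --max_seq_length 1024
training_args, custom_args = parser.parse_args_into_dataclasses(["--output_dir", "./out", "--max_seq_length", "1024"])
assert custom_args.max_seq_length == 1024
```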
examples/huggingface/training_multimodal.py
@@ -1,11 +1,15 @@
 import os
+
+from dataclasses import dataclass
+
+import datasets
 import torch
 import transformers
-
-from
-from trl import SFTTrainer, SFTConfig
-from trl.trainer import ConstantLengthDataset
+
+from callback import EfficiencyCallback
 from datasets import Image as ImageFeature
+from trl import SFTTrainer
+
 from liger_kernel.transformers import monkey_patch
 
 
@@ -15,6 +19,8 @@ class CustomArguments:
     dataset: str = "HuggingFaceM4/the_cauldron"
     dataset_subset: str = "ai2d"
     dataset_split: str = "train"
+    max_seq_length: int = 512
+    dataset_text_field: str = "texts"
     use_liger: bool = False
 
 
@@ -89,37 +95,75 @@ def _format_for_convo(example, tokenizer):
 def train():
     parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments))
     training_args, custom_args = parser.parse_args_into_dataclasses()
+    training_args.remove_unused_columns = False  # required to not drop the image column
+    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
 
-    model, processor, image_token_id = construct_model_and_processor(
-        custom_args.model_name, custom_args.use_liger
-    )
+    model, processor, image_token_id = construct_model_and_processor(custom_args.model_name, custom_args.use_liger)
 
-    dataset =
-
-
-
+    dataset = (
+        datasets.load_dataset(
+            custom_args.dataset,
+            custom_args.dataset_subset,
+            split=custom_args.dataset_split,
+        )
+        .map(
+            _validate_and_extract_the_cauldron,
+            batched=True,
+            num_proc=min(os.cpu_count(), 16),
+            desc="Extracting text and images",
+        )
+        .map(
+            _format_for_convo,
+            fn_kwargs={"tokenizer": processor.tokenizer},
+            desc="Formatting for convo",
+        )
+        .cast_column("images", ImageFeature())
+        .train_test_split(test_size=0.1)
     )
 
-    train_dataset
+    train_dataset = dataset["train"]
+    eval_dataset = dataset["test"]
 
-
-
-
-
-
-
-
-
+    def collate_fn(examples):
+        """
+        Taken directly from the TRL documentation with minor modifications:
+        https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data
+
+        Modifications:
+        1. `apply_chat_template` is used to preprocess the texts before training begins (see above)
+        2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema
+        3. Ignoring image tokens in the loss computation
+        """
+        # Get the texts and images
+        texts = [example["texts"] for example in examples]
+        images = [example["images"] for example in examples]
+
+        # Tokenize the texts and process the images
+        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+        # The labels are the input_ids, and we mask the padding tokens in the loss computation
+        labels = batch["input_ids"].clone()
+        labels[labels == processor.tokenizer.pad_token_id] = -100
+
+        # Ignore the image token index in the loss computation
+        labels[labels == image_token_id] = -100
+        batch["labels"] = labels
+
+        return batch
 
     trainer = SFTTrainer(
         model=model,
-        args=
+        args=training_args,
+        data_collator=collate_fn,
+        max_seq_length=custom_args.max_seq_length,
+        dataset_text_field=custom_args.dataset_text_field,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-
+        tokenizer=processor.tokenizer,
+        callbacks=[EfficiencyCallback()],
     )
     trainer.train()
 
 
 if __name__ == "__main__":
-    train()
+    train()
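The core of the new `collate_fn` is the label-masking step: padding and image tokens are set to -100 so the loss ignores them. A toy sketch of just that logic, with made-up ids (`pad_token_id=0`, `image_token_id=32000` are illustrative values, not the real tokenizer's):

```python
import torch

input_ids = torch.tensor([[5, 7, 32000, 9, 0, 0]])  # one sequence: an image token plus right padding
labels = input_ids.clone()
labels[labels == 0] = -100      # padding contributes nothing to the loss
labels[labels == 32000] = -100  # image tokens are likewise ignored
print(labels)  # tensor([[   5,    7, -100,    9, -100, -100]])
```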
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.5.3.dev20250221162633"
+version = "0.5.3.dev20250221230243"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
src/liger_kernel/ops/tvd.py
@@ -1,4 +1,5 @@
-from typing import Literal
+from typing import Literal
+from typing import Optional
 
 import torch
 import triton
@@ -178,15 +179,13 @@ class LigerTVDLossFunction(torch.autograd.Function):
         """
         has_label = False
         if shift_labels is not None:
-            assert shift_labels.shape == (
-                p.shape[0],
-            )
+            assert shift_labels.shape == (p.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
             shift_labels = shift_labels.contiguous()
             has_label = True
 
-        loss, grads = tv_distance_forward_triton(
-            p, q, shift_labels, reduction, ignore_index, has_label
-        )
+        loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
         ctx.save_for_backward(grads)
         return loss
 
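For context, `LigerTVDLossFunction` computes the total variation distance, which for discrete distributions is half the L1 distance between them. A hedged pure-PyTorch reference (the Triton kernel additionally handles `shift_labels`, `ignore_index`, and multiple reduction modes; the batch-mean reduction here is one simplifying assumption):

```python
import torch

def tvd_reference(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # TVD(p, q) = 0.5 * sum_i |p_i - q_i|, here averaged over the batch dimension
    return 0.5 * (p - q).abs().sum(dim=-1).mean()
```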
src/liger_kernel/transformers/functional.py
@@ -14,6 +14,7 @@ from liger_kernel.ops.rope import LigerRopeFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction
 from liger_kernel.ops.tvd import LigerTVDLossFunction
 
+
 # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
 # `weight` and `size_average` are placeholders and not implemented yet
 def liger_cross_entropy(
@@ -156,6 +157,7 @@ def liger_kl_div(
         eps,
     )
 
+
 def liger_tvd(
     input,
     target,
@@ -169,7 +171,8 @@ def liger_tvd(
         shift_labels,
         reduction,
         ignore_index,
-    )
+    )
+
 
 def liger_layer_norm(X, W, B, eps):
     return LigerLayerNormFunction.apply(X, W, B, eps)
src/liger_kernel/transformers/tvd.py
@@ -10,6 +10,4 @@ class LigerTVDLoss(nn.Module):
         self.ignore_index = ignore_index
 
     def forward(self, p, q, shift_labels=None):
-        return LigerTVDLossFunction.apply(
-            p, q, shift_labels, self.reduction, self.ignore_index
-        )
+        return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index)
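A hedged usage sketch of the one-line `forward` above. The `reduction` value and the `(BT, V)` input shapes are assumptions inferred from the `(BT,)` assert in the autograd function, and a CUDA device is needed for the Triton kernel:

```python
# Sketch only: reduction="batchmean" and the tensor shapes are assumed values.
import torch

from liger_kernel.transformers.tvd import LigerTVDLoss

tvd = LigerTVDLoss(reduction="batchmean")
p = torch.randn(4, 128, device="cuda").softmax(dim=-1)  # two probability distributions
q = torch.randn(4, 128, device="cuda").softmax(dim=-1)
loss = tvd(p, q)
```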
test/convergence/bf16/test_mini_models.py
@@ -7,8 +7,6 @@ from transformers.models.gemma import GemmaConfig
 from transformers.models.gemma import GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config
 from transformers.models.gemma2 import Gemma2ForCausalLM
-from transformers.models.granite import GraniteConfig
-from transformers.models.granite import GraniteForCausalLM
 from transformers.models.llama import LlamaConfig
 from transformers.models.llama import LlamaForCausalLM
 from transformers.models.mistral import MistralConfig
@@ -65,44 +63,19 @@ try:
 except ImportError:
     QWEN2_VL_AVAILABLE = False
 
+try:
+    from transformers.models.granite import GraniteConfig
+    from transformers.models.granite import GraniteForCausalLM
+
+    GRANITE_AVAILABLE = True
+except ImportError:
+    GRANITE_AVAILABLE = False
+
 from liger_kernel.utils import infer_device
 
 device = infer_device()
 
 MINI_MODEL_SETUPS = {
-    "mini_granite3": MiniModelConfig(
-        liger_kernel_patch_func=apply_liger_kernel_to_granite,
-        liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
-        model_class=GraniteForCausalLM,
-        mini_model_config=GraniteConfig(
-            attention_bias=False,
-            attention_dropout=0.1,
-            # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
-            # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
-            bos_token_id=1,  # 128000
-            eos_token_id=2,  # 128001
-            hidden_act="silu",
-            hidden_size=1024,  # 4096
-            initializer_range=0.02,
-            intermediate_size=2048,  # 14336
-            max_position_embeddings=8192,
-            num_attention_heads=8,  # 32
-            num_hidden_layers=4,  # 32
-            num_key_value_heads=2,  # 8
-            pretraining_tp=1,
-            rms_norm_eps=1e-5,
-            rope_scaling=None,
-            rope_theta=500000.0,
-            tie_word_embeddings=False,
-            use_cache=True,
-            vocab_size=32000,  # 128256,
-            # At rope backward
-            # Eager produces incontiguous dq and dk
-            # SDPA produces contiguous dq and incontiguous dk
-            # Flash_attn produces contiguous dq and dk
-            attn_implementation="sdpa",  # default value, pytorch native attention
-        ),
-    ),
     "mini_llama3": MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_llama,
         liger_kernel_patch_revert_func=revert_liger_kernel_to_llama,
@@ -418,6 +391,41 @@ if QWEN2_VL_AVAILABLE:
     ),
 )
 
+if GRANITE_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_granite,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
+        model_class=GraniteForCausalLM,
+        mini_model_config=GraniteConfig(
+            attention_bias=False,
+            attention_dropout=0.1,
+            # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
+            # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
+            bos_token_id=1,  # 128000
+            eos_token_id=2,  # 128001
+            hidden_act="silu",
+            hidden_size=1024,  # 4096
+            initializer_range=0.02,
+            intermediate_size=2048,  # 14336
+            max_position_embeddings=8192,
+            num_attention_heads=8,  # 32
+            num_hidden_layers=4,  # 32
+            num_key_value_heads=2,  # 8
+            pretraining_tp=1,
+            rms_norm_eps=1e-5,
+            rope_scaling=None,
+            rope_theta=500000.0,
+            tie_word_embeddings=False,
+            use_cache=True,
+            vocab_size=32000,  # 128256,
+            # At rope backward
+            # Eager produces incontiguous dq and dk
+            # SDPA produces contiguous dq and incontiguous dk
+            # Flash_attn produces contiguous dq and dk
+            attn_implementation="sdpa",  # default value, pytorch native attention
+        ),
+    )
+
 
 def create_model(model_name="mini_llama3"):
     """
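The two hunks above are the substance of this release's test refactor: Granite model classes only exist in newer `transformers` releases, so the unconditional import and dict entry are replaced by a guarded import plus conditional registration, with matching `pytest.mark.skipif` guards in the hunks that follow. A condensed sketch of the pattern, with shortened stand-in names:

```python
# Sketch only: the test function and the dict value are stand-ins for the
# real MiniModelConfig registration in the convergence test files.
import pytest

try:
    from transformers.models.granite import GraniteForCausalLM

    GRANITE_AVAILABLE = True
except ImportError:  # older transformers without Granite
    GRANITE_AVAILABLE = False

MINI_MODEL_SETUPS = {}
if GRANITE_AVAILABLE:
    MINI_MODEL_SETUPS["mini_granite3"] = GraniteForCausalLM

@pytest.mark.skipif(not GRANITE_AVAILABLE, reason="Granite not available in this version of transformers")
def test_granite_registered():
    assert "mini_granite3" in MINI_MODEL_SETUPS
```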
@@ -462,7 +470,8 @@ def run_mini_model(
     else:
         kwargs["swiglu"] = True
 
-    kwargs["fused_linear_cross_entropy"] = True
+    # fused_linear_cross_entropy is not supported in mini_granite3
+    kwargs["fused_linear_cross_entropy"] = True if model_name != "mini_granite3" else False
     kwargs["cross_entropy"] = False
 
     MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
@@ -518,7 +527,13 @@ def run_mini_model(
             1e-2,
             1e-2,
             1e-2,
-            marks=
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not GRANITE_AVAILABLE,
+                    reason="Granite not available in this version of transformers",
+                ),
+            ],
         ),
         pytest.param(
             "mini_mllama",
test/convergence/bf16/test_mini_models_with_logits.py
@@ -7,8 +7,6 @@ from transformers.models.gemma import GemmaConfig
 from transformers.models.gemma import GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config
 from transformers.models.gemma2 import Gemma2ForCausalLM
-from transformers.models.granite import GraniteConfig
-from transformers.models.granite import GraniteForCausalLM
 from transformers.models.llama import LlamaConfig
 from transformers.models.llama import LlamaForCausalLM
 from transformers.models.mistral import MistralConfig
@@ -65,6 +63,14 @@ try:
 except ImportError:
     QWEN2_VL_AVAILABLE = False
 
+try:
+    from transformers.models.granite import GraniteConfig
+    from transformers.models.granite import GraniteForCausalLM
+
+    GRANITE_AVAILABLE = True
+except ImportError:
+    GRANITE_AVAILABLE = False
+
 from liger_kernel.utils import infer_device
 
 device = infer_device()
@@ -103,40 +109,6 @@ MINI_MODEL_SETUPS = {
             attn_implementation="sdpa",  # default value, pytorch native attention
         ),
     ),
-    "mini_granite3": MiniModelConfig(
-        liger_kernel_patch_func=apply_liger_kernel_to_granite,
-        liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
-        model_class=GraniteForCausalLM,
-        mini_model_config=GraniteConfig(
-            attention_bias=False,
-            attention_dropout=0.0,
-            # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
-            # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
-            bos_token_id=1,  # 128000
-            eos_token_id=2,  # 128001
-            hidden_act="silu",
-            hidden_size=1024,  # 4096
-            initializer_range=0.02,
-            intermediate_size=2048,  # 14336
-            max_position_embeddings=8192,
-            num_attention_heads=8,  # 32
-            num_hidden_layers=4,  # 32
-            num_key_value_heads=2,  # 8
-            pretraining_tp=1,
-            rms_norm_eps=1e-5,
-            rope_scaling=None,
-            rope_theta=500000.0,
-            tie_word_embeddings=False,
-            use_cache=True,
-            vocab_size=32000,  # 128256,
-            logits_scaling=8.0,
-            # At rope backward
-            # Eager produces incontiguous dq and dk
-            # SDPA produces contiguous dq and incontiguous dk
-            # Flash_attn produces contiguous dq and dk
-            attn_implementation="sdpa",  # default value, pytorch native attention
-        ),
-    ),
     "mini_qwen2": MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_qwen2,
         liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2,
@@ -419,6 +391,42 @@ if QWEN2_VL_AVAILABLE:
     ),
 )
 
+if GRANITE_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_granite,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
+        model_class=GraniteForCausalLM,
+        mini_model_config=GraniteConfig(
+            attention_bias=False,
+            attention_dropout=0.0,
+            # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
+            # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
+            bos_token_id=1,  # 128000
+            eos_token_id=2,  # 128001
+            hidden_act="silu",
+            hidden_size=1024,  # 4096
+            initializer_range=0.02,
+            intermediate_size=2048,  # 14336
+            max_position_embeddings=8192,
+            num_attention_heads=8,  # 32
+            num_hidden_layers=4,  # 32
+            num_key_value_heads=2,  # 8
+            pretraining_tp=1,
+            rms_norm_eps=1e-5,
+            rope_scaling=None,
+            rope_theta=500000.0,
+            tie_word_embeddings=False,
+            use_cache=True,
+            vocab_size=32000,  # 128256,
+            logits_scaling=8.0,
+            # At rope backward
+            # Eager produces incontiguous dq and dk
+            # SDPA produces contiguous dq and incontiguous dk
+            # Flash_attn produces contiguous dq and dk
+            attn_implementation="sdpa",  # default value, pytorch native attention
+        ),
+    )
+
 
 def create_model(model_name="mini_llama3"):
     """
@@ -518,7 +526,13 @@ def run_mini_model(
             1e-2,  # logits rtol
             1e-2,
             1e-2,
-            marks=
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not GRANITE_AVAILABLE,
+                    reason="Granite not available in this version of transformers",
+                ),
+            ],
         ),
         pytest.param(
             "mini_mllama",
test/convergence/fp32/test_mini_models.py
@@ -7,8 +7,6 @@ from transformers.models.gemma import GemmaConfig
 from transformers.models.gemma import GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config
 from transformers.models.gemma2 import Gemma2ForCausalLM
-from transformers.models.granite import GraniteConfig
-from transformers.models.granite import GraniteForCausalLM
 from transformers.models.llama import LlamaConfig
 from transformers.models.llama import LlamaForCausalLM
 from transformers.models.mistral import MistralConfig
@@ -64,44 +62,19 @@ try:
 except ImportError:
     QWEN2_VL_AVAILABLE = False
 
+try:
+    from transformers.models.granite import GraniteConfig
+    from transformers.models.granite import GraniteForCausalLM
+
+    GRANITE_AVAILABLE = True
+except ImportError:
+    GRANITE_AVAILABLE = False
+
 from liger_kernel.utils import infer_device
 
 device = infer_device()
 
 MINI_MODEL_SETUPS = {
-    "mini_granite3": MiniModelConfig(
-        liger_kernel_patch_func=apply_liger_kernel_to_granite,
-        liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
-        model_class=GraniteForCausalLM,
-        mini_model_config=GraniteConfig(
-            attention_bias=False,
-            attention_dropout=0.1,
-            # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
-            # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
-            bos_token_id=1,  # 128000
-            eos_token_id=2,  # 128001
-            hidden_act="silu",
-            hidden_size=1024,  # 4096
-            initializer_range=0.02,
-            intermediate_size=2048,  # 14336
-            max_position_embeddings=8192,
-            num_attention_heads=8,  # 32
-            num_hidden_layers=4,  # 32
-            num_key_value_heads=2,  # 8
-            pretraining_tp=1,
-            rms_norm_eps=1e-5,
-            rope_scaling=None,
-            rope_theta=500000.0,
-            tie_word_embeddings=False,
-            use_cache=True,
-            vocab_size=32000,  # 128256,
-            # At rope backward
-            # Eager produces incontiguous dq and dk
-            # SDPA produces contiguous dq and incontiguous dk
-            # Flash_attn produces contiguous dq and dk
-            attn_implementation="sdpa",  # default value, pytorch native attention
-        ),
-    ),
     "mini_llama3": MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_llama,
         liger_kernel_patch_revert_func=revert_liger_kernel_to_llama,
@@ -417,6 +390,41 @@ if QWEN2_VL_AVAILABLE:
     ),
 )
 
+if GRANITE_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_granite,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
+        model_class=GraniteForCausalLM,
+        mini_model_config=GraniteConfig(
+            attention_bias=False,
+            attention_dropout=0.1,
+            # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
+            # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
+            bos_token_id=1,  # 128000
+            eos_token_id=2,  # 128001
+            hidden_act="silu",
+            hidden_size=1024,  # 4096
+            initializer_range=0.02,
+            intermediate_size=2048,  # 14336
+            max_position_embeddings=8192,
+            num_attention_heads=8,  # 32
+            num_hidden_layers=4,  # 32
+            num_key_value_heads=2,  # 8
+            pretraining_tp=1,
+            rms_norm_eps=1e-5,
+            rope_scaling=None,
+            rope_theta=500000.0,
+            tie_word_embeddings=False,
+            use_cache=True,
+            vocab_size=32000,  # 128256,
+            # At rope backward
+            # Eager produces incontiguous dq and dk
+            # SDPA produces contiguous dq and incontiguous dk
+            # Flash_attn produces contiguous dq and dk
+            attn_implementation="sdpa",  # default value, pytorch native attention
+        ),
+    )
+
 
 def create_model(model_name="mini_llama3"):
     """
@@ -461,7 +469,8 @@ def run_mini_model(
     else:
         kwargs["swiglu"] = True
 
-    kwargs["fused_linear_cross_entropy"] = True
+    # fused_linear_cross_entropy is not supported in mini_granite3
+    kwargs["fused_linear_cross_entropy"] = True if model_name != "mini_granite3" else False
     kwargs["cross_entropy"] = False
 
     MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
@@ -535,7 +544,22 @@ def run_mini_model(
         ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
         ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
         ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
-        (
+        pytest.param(
+            "mini_granite3",
+            32,
+            1e-4,
+            torch.float32,
+            1e-8,
+            1e-4,
+            5e-3,
+            1e-5,
+            5e-3,
+            1e-5,
+            marks=pytest.mark.skipif(
+                not GRANITE_AVAILABLE,
+                reason="Granite not available in this version of transformers",
+            ),
+        ),
     ],
 )
 def test_mini_model(
|