liger-kernel-nightly 0.5.3.dev20250221162633__tar.gz → 0.5.3.dev20250221233257__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_rope.py +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_tvd.py +8 -11
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/dev/modal/tests.py +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/training.py +2 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/training_multimodal.py +67 -23
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/tvd.py +6 -7
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/functional.py +4 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/tvd.py +1 -3
- liger_kernel_nightly-0.5.3.dev20250221233257/src/liger_kernel/utils.py +62 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/convergence/bf16/test_mini_models.py +52 -37
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/convergence/bf16/test_mini_models_with_logits.py +51 -37
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/convergence/fp32/test_mini_models.py +61 -37
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/convergence/fp32/test_mini_models_with_logits.py +60 -37
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_flex_attention.py +25 -17
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_rope.py +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_tvd.py +13 -20
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/utils.py +0 -47
- liger_kernel_nightly-0.5.3.dev20250221162633/src/liger_kernel/utils.py +0 -13
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/Makefile +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/setup.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/__init__.py +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633 → liger_kernel_nightly-0.5.3.dev20250221233257}/test/triton/test_triton_monkey_patch.py +0 -0
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
|
|
4
|
-
from test.utils import transformers_version_dispatch
|
|
5
4
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
|
6
5
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
|
7
6
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
@@ -14,6 +13,7 @@ from utils import run_benchmarks
|
|
|
14
13
|
|
|
15
14
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
16
15
|
from liger_kernel.utils import infer_device
|
|
16
|
+
from liger_kernel.utils import transformers_version_dispatch
|
|
17
17
|
|
|
18
18
|
device = infer_device()
|
|
19
19
|
|
|
@@ -1,13 +1,12 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
)
|
|
3
|
+
|
|
4
|
+
from utils import QUANTILES
|
|
5
|
+
from utils import SingleBenchmarkRunInput
|
|
6
|
+
from utils import SingleBenchmarkRunOutput
|
|
7
|
+
from utils import _test_memory
|
|
8
|
+
from utils import parse_benchmark_script_args
|
|
9
|
+
from utils import run_benchmarks
|
|
11
10
|
|
|
12
11
|
from liger_kernel.transformers.tvd import LigerTVDLoss
|
|
13
12
|
|
|
@@ -67,9 +66,7 @@ def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
|
|
|
67
66
|
y = fwd()
|
|
68
67
|
y.backward(retain_graph=True)
|
|
69
68
|
|
|
70
|
-
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
|
71
|
-
full, quantiles=QUANTILES, rep=100
|
|
72
|
-
)
|
|
69
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
|
|
73
70
|
return SingleBenchmarkRunOutput(
|
|
74
71
|
y_20=ms_20,
|
|
75
72
|
y_50=ms_50,
|
|
@@ -14,7 +14,7 @@ app = modal.App("liger_tests", image=image)
|
|
|
14
14
|
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
@app.function(gpu="A10G", mounts=[repo], timeout=60 *
|
|
17
|
+
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 20)
|
|
18
18
|
def liger_tests():
|
|
19
19
|
import subprocess
|
|
20
20
|
|
|
@@ -15,6 +15,7 @@ from liger_kernel.transformers import AutoLigerKernelForCausalLM
|
|
|
15
15
|
class CustomArguments:
|
|
16
16
|
model_name: str = "meta-llama/Meta-Llama-3-8B"
|
|
17
17
|
dataset: str = "tatsu-lab/alpaca"
|
|
18
|
+
max_seq_length: int = 512
|
|
18
19
|
use_liger: bool = False
|
|
19
20
|
|
|
20
21
|
|
|
@@ -65,6 +66,7 @@ def train():
|
|
|
65
66
|
model=model,
|
|
66
67
|
args=training_args,
|
|
67
68
|
data_collator=collator,
|
|
69
|
+
max_seq_length=custom_args.max_seq_length,
|
|
68
70
|
train_dataset=train_dataset,
|
|
69
71
|
eval_dataset=eval_dataset,
|
|
70
72
|
formatting_func=formatting_prompts_func,
|
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
import os
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
import datasets
|
|
2
6
|
import torch
|
|
3
7
|
import transformers
|
|
4
|
-
|
|
5
|
-
from
|
|
6
|
-
from trl import SFTTrainer, SFTConfig
|
|
7
|
-
from trl.trainer import ConstantLengthDataset
|
|
8
|
+
|
|
9
|
+
from callback import EfficiencyCallback
|
|
8
10
|
from datasets import Image as ImageFeature
|
|
11
|
+
from trl import SFTTrainer
|
|
12
|
+
|
|
9
13
|
from liger_kernel.transformers import monkey_patch
|
|
10
14
|
|
|
11
15
|
|
|
@@ -15,6 +19,8 @@ class CustomArguments:
|
|
|
15
19
|
dataset: str = "HuggingFaceM4/the_cauldron"
|
|
16
20
|
dataset_subset: str = "ai2d"
|
|
17
21
|
dataset_split: str = "train"
|
|
22
|
+
max_seq_length: int = 512
|
|
23
|
+
dataset_text_field: str = "texts"
|
|
18
24
|
use_liger: bool = False
|
|
19
25
|
|
|
20
26
|
|
|
@@ -89,37 +95,75 @@ def _format_for_convo(example, tokenizer):
|
|
|
89
95
|
def train():
|
|
90
96
|
parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments))
|
|
91
97
|
training_args, custom_args = parser.parse_args_into_dataclasses()
|
|
98
|
+
training_args.remove_unused_columns = False # required to not drop the image column
|
|
99
|
+
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
|
92
100
|
|
|
93
|
-
model, processor, image_token_id = construct_model_and_processor(
|
|
94
|
-
custom_args.model_name, custom_args.use_liger
|
|
95
|
-
)
|
|
101
|
+
model, processor, image_token_id = construct_model_and_processor(custom_args.model_name, custom_args.use_liger)
|
|
96
102
|
|
|
97
|
-
dataset =
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
103
|
+
dataset = (
|
|
104
|
+
datasets.load_dataset(
|
|
105
|
+
custom_args.dataset,
|
|
106
|
+
custom_args.dataset_subset,
|
|
107
|
+
split=custom_args.dataset_split,
|
|
108
|
+
)
|
|
109
|
+
.map(
|
|
110
|
+
_validate_and_extract_the_cauldron,
|
|
111
|
+
batched=True,
|
|
112
|
+
num_proc=min(os.cpu_count(), 16),
|
|
113
|
+
desc="Extracting text and images",
|
|
114
|
+
)
|
|
115
|
+
.map(
|
|
116
|
+
_format_for_convo,
|
|
117
|
+
fn_kwargs={"tokenizer": processor.tokenizer},
|
|
118
|
+
desc="Formatting for convo",
|
|
119
|
+
)
|
|
120
|
+
.cast_column("images", ImageFeature())
|
|
121
|
+
.train_test_split(test_size=0.1)
|
|
101
122
|
)
|
|
102
123
|
|
|
103
|
-
train_dataset
|
|
124
|
+
train_dataset = dataset["train"]
|
|
125
|
+
eval_dataset = dataset["test"]
|
|
104
126
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
127
|
+
def collate_fn(examples):
|
|
128
|
+
"""
|
|
129
|
+
Taken directly from the TRL documentation with minor modifications:
|
|
130
|
+
https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data
|
|
131
|
+
|
|
132
|
+
Modifications:
|
|
133
|
+
1. `apply_chat_template` is used to preprocess the texts before training begins (see above)
|
|
134
|
+
2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema
|
|
135
|
+
3. Ignoring image tokens in the loss computation
|
|
136
|
+
"""
|
|
137
|
+
# Get the texts and images
|
|
138
|
+
texts = [example["texts"] for example in examples]
|
|
139
|
+
images = [example["images"] for example in examples]
|
|
140
|
+
|
|
141
|
+
# Tokenize the texts and process the images
|
|
142
|
+
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
|
143
|
+
|
|
144
|
+
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
|
145
|
+
labels = batch["input_ids"].clone()
|
|
146
|
+
labels[labels == processor.tokenizer.pad_token_id] = -100
|
|
147
|
+
|
|
148
|
+
# Ignore the image token index in the loss computation
|
|
149
|
+
labels[labels == image_token_id] = -100
|
|
150
|
+
batch["labels"] = labels
|
|
151
|
+
|
|
152
|
+
return batch
|
|
113
153
|
|
|
114
154
|
trainer = SFTTrainer(
|
|
115
155
|
model=model,
|
|
116
|
-
args=
|
|
156
|
+
args=training_args,
|
|
157
|
+
data_collator=collate_fn,
|
|
158
|
+
max_seq_length=custom_args.max_seq_length,
|
|
159
|
+
dataset_text_field=custom_args.dataset_text_field,
|
|
117
160
|
train_dataset=train_dataset,
|
|
118
161
|
eval_dataset=eval_dataset,
|
|
119
|
-
|
|
162
|
+
tokenizer=processor.tokenizer,
|
|
163
|
+
callbacks=[EfficiencyCallback()],
|
|
120
164
|
)
|
|
121
165
|
trainer.train()
|
|
122
166
|
|
|
123
167
|
|
|
124
168
|
if __name__ == "__main__":
|
|
125
|
-
train()
|
|
169
|
+
train()
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel_nightly"
|
|
7
|
-
version = "0.5.3.dev20250221162633"
|
|
7
|
+
version = "0.5.3.dev20250221233257"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from typing import Literal
|
|
1
|
+
from typing import Literal
|
|
2
|
+
from typing import Optional
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
import triton
|
|
@@ -178,15 +179,13 @@ class LigerTVDLossFunction(torch.autograd.Function):
|
|
|
178
179
|
"""
|
|
179
180
|
has_label = False
|
|
180
181
|
if shift_labels is not None:
|
|
181
|
-
assert shift_labels.shape == (
|
|
182
|
-
|
|
183
|
-
)
|
|
182
|
+
assert shift_labels.shape == (p.shape[0],), (
|
|
183
|
+
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
|
|
184
|
+
)
|
|
184
185
|
shift_labels = shift_labels.contiguous()
|
|
185
186
|
has_label = True
|
|
186
187
|
|
|
187
|
-
loss, grads = tv_distance_forward_triton(
|
|
188
|
-
p, q, shift_labels, reduction, ignore_index, has_label
|
|
189
|
-
)
|
|
188
|
+
loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
|
|
190
189
|
ctx.save_for_backward(grads)
|
|
191
190
|
return loss
|
|
192
191
|
|
|
@@ -14,6 +14,7 @@ from liger_kernel.ops.rope import LigerRopeFunction
|
|
|
14
14
|
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
|
|
15
15
|
from liger_kernel.ops.tvd import LigerTVDLossFunction
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
|
|
18
19
|
# `weight` and `size_average` are placeholders and not implemented yet
|
|
19
20
|
def liger_cross_entropy(
|
|
@@ -156,6 +157,7 @@ def liger_kl_div(
|
|
|
156
157
|
eps,
|
|
157
158
|
)
|
|
158
159
|
|
|
160
|
+
|
|
159
161
|
def liger_tvd(
|
|
160
162
|
input,
|
|
161
163
|
target,
|
|
@@ -169,7 +171,8 @@ def liger_tvd(
|
|
|
169
171
|
shift_labels,
|
|
170
172
|
reduction,
|
|
171
173
|
ignore_index,
|
|
172
|
-
)
|
|
174
|
+
)
|
|
175
|
+
|
|
173
176
|
|
|
174
177
|
def liger_layer_norm(X, W, B, eps):
|
|
175
178
|
return LigerLayerNormFunction.apply(X, W, B, eps)
|
|
@@ -10,6 +10,4 @@ class LigerTVDLoss(nn.Module):
|
|
|
10
10
|
self.ignore_index = ignore_index
|
|
11
11
|
|
|
12
12
|
def forward(self, p, q, shift_labels=None):
|
|
13
|
-
return LigerTVDLossFunction.apply(
|
|
14
|
-
p, q, shift_labels, self.reduction, self.ignore_index
|
|
15
|
-
)
|
|
13
|
+
return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def infer_device():
    """
    Get current device name based on available devices.

    Backends are probed in priority order: ``cuda``, then ``xpu``, then
    ``hip``. A backend is selected only when the installed torch build
    exposes ``torch.<backend>.is_available`` and that call returns a truthy
    value. Backends missing from the build are skipped instead of raising
    ``AttributeError``, so the function always falls back to ``"cpu"``.

    Returns:
        str: one of ``"cuda"``, ``"xpu"``, ``"hip"`` or ``"cpu"``.
    """
    # NOTE(review): ROCm builds of PyTorch are documented to report through
    # torch.cuda, so the "hip" probe is presumably never hit in practice;
    # it is kept (guarded) only to preserve the original return values.
    for backend_name in ("cuda", "xpu", "hip"):
        backend = getattr(torch, backend_name, None)
        is_available = getattr(backend, "is_available", None)
        if callable(is_available) and is_available():
            return backend_name
    return "cpu"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def transformers_version_dispatch(
    required_version: str,
    before_fn,
    after_fn,
    before_args: tuple = (),
    after_args: tuple = (),
    before_kwargs: dict = None,
    after_kwargs: dict = None,
):
    """
    Call one of two callables depending on the installed transformers version.

    The installed ``transformers.__version__`` is compared against
    ``required_version``. When the installed version is strictly older,
    ``before_fn`` is invoked with ``before_args`` / ``before_kwargs``;
    otherwise ``after_fn`` is invoked with ``after_args`` / ``after_kwargs``.

    Args:
        required_version: Version to compare against (e.g. "4.48.0")
        before_fn: Callable used when installed version < required_version
        after_fn: Callable used when installed version >= required_version
        before_args: Positional arguments forwarded to before_fn
        after_args: Positional arguments forwarded to after_fn
        before_kwargs: Keyword arguments forwarded to before_fn (None means none)
        after_kwargs: Keyword arguments forwarded to after_fn (None means none)

    Returns:
        Whatever the selected callable returns.

    Example:
        >>> rotary_emb = transformers_version_dispatch(
        ...     "4.48.0",
        ...     LlamaRotaryEmbedding,
        ...     LlamaRotaryEmbedding,
        ...     before_args=(head_dim,),
        ...     after_args=(LlamaConfig(head_dim=head_dim),),
        ...     before_kwargs={'device': device},
        ...     after_kwargs={'device': device}
        ... )
    """
    # Imported lazily so that importing this module does not require
    # transformers / packaging to be importable.
    from packaging import version
    from transformers import __version__ as transformers_version

    installed = version.parse(transformers_version)
    threshold = version.parse(required_version)

    if installed < threshold:
        chosen_fn, chosen_args, chosen_kwargs = before_fn, before_args, before_kwargs
    else:
        chosen_fn, chosen_args, chosen_kwargs = after_fn, after_args, after_kwargs

    # A missing (None) or empty kwargs mapping is treated as "no keyword arguments".
    return chosen_fn(*chosen_args, **(chosen_kwargs or {}))
|
|
@@ -7,8 +7,6 @@ from transformers.models.gemma import GemmaConfig
|
|
|
7
7
|
from transformers.models.gemma import GemmaForCausalLM
|
|
8
8
|
from transformers.models.gemma2 import Gemma2Config
|
|
9
9
|
from transformers.models.gemma2 import Gemma2ForCausalLM
|
|
10
|
-
from transformers.models.granite import GraniteConfig
|
|
11
|
-
from transformers.models.granite import GraniteForCausalLM
|
|
12
10
|
from transformers.models.llama import LlamaConfig
|
|
13
11
|
from transformers.models.llama import LlamaForCausalLM
|
|
14
12
|
from transformers.models.mistral import MistralConfig
|
|
@@ -65,44 +63,19 @@ try:
|
|
|
65
63
|
except ImportError:
|
|
66
64
|
QWEN2_VL_AVAILABLE = False
|
|
67
65
|
|
|
66
|
+
try:
|
|
67
|
+
from transformers.models.granite import GraniteConfig
|
|
68
|
+
from transformers.models.granite import GraniteForCausalLM
|
|
69
|
+
|
|
70
|
+
GRANITE_AVAILABLE = True
|
|
71
|
+
except ImportError:
|
|
72
|
+
GRANITE_AVAILABLE = False
|
|
73
|
+
|
|
68
74
|
from liger_kernel.utils import infer_device
|
|
69
75
|
|
|
70
76
|
device = infer_device()
|
|
71
77
|
|
|
72
78
|
MINI_MODEL_SETUPS = {
|
|
73
|
-
"mini_granite3": MiniModelConfig(
|
|
74
|
-
liger_kernel_patch_func=apply_liger_kernel_to_granite,
|
|
75
|
-
liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
|
|
76
|
-
model_class=GraniteForCausalLM,
|
|
77
|
-
mini_model_config=GraniteConfig(
|
|
78
|
-
attention_bias=False,
|
|
79
|
-
attention_dropout=0.1,
|
|
80
|
-
# Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
|
|
81
|
-
# https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
|
|
82
|
-
bos_token_id=1, # 128000
|
|
83
|
-
eos_token_id=2, # 128001
|
|
84
|
-
hidden_act="silu",
|
|
85
|
-
hidden_size=1024, # 4096
|
|
86
|
-
initializer_range=0.02,
|
|
87
|
-
intermediate_size=2048, # 14336
|
|
88
|
-
max_position_embeddings=8192,
|
|
89
|
-
num_attention_heads=8, # 32
|
|
90
|
-
num_hidden_layers=4, # 32
|
|
91
|
-
num_key_value_heads=2, # 8
|
|
92
|
-
pretraining_tp=1,
|
|
93
|
-
rms_norm_eps=1e-5,
|
|
94
|
-
rope_scaling=None,
|
|
95
|
-
rope_theta=500000.0,
|
|
96
|
-
tie_word_embeddings=False,
|
|
97
|
-
use_cache=True,
|
|
98
|
-
vocab_size=32000, # 128256,
|
|
99
|
-
# At rope backward
|
|
100
|
-
# Eager produces incontiguous dq and dk
|
|
101
|
-
# SDPA produces contiguous dq and incontiguous dk
|
|
102
|
-
# Flash_attn produces contiguous dq and dk
|
|
103
|
-
attn_implementation="sdpa", # default value, pytorch native attention
|
|
104
|
-
),
|
|
105
|
-
),
|
|
106
79
|
"mini_llama3": MiniModelConfig(
|
|
107
80
|
liger_kernel_patch_func=apply_liger_kernel_to_llama,
|
|
108
81
|
liger_kernel_patch_revert_func=revert_liger_kernel_to_llama,
|
|
@@ -418,6 +391,41 @@ if QWEN2_VL_AVAILABLE:
|
|
|
418
391
|
),
|
|
419
392
|
)
|
|
420
393
|
|
|
394
|
+
if GRANITE_AVAILABLE:
|
|
395
|
+
MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig(
|
|
396
|
+
liger_kernel_patch_func=apply_liger_kernel_to_granite,
|
|
397
|
+
liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
|
|
398
|
+
model_class=GraniteForCausalLM,
|
|
399
|
+
mini_model_config=GraniteConfig(
|
|
400
|
+
attention_bias=False,
|
|
401
|
+
attention_dropout=0.1,
|
|
402
|
+
# Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
|
|
403
|
+
# https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
|
|
404
|
+
bos_token_id=1, # 128000
|
|
405
|
+
eos_token_id=2, # 128001
|
|
406
|
+
hidden_act="silu",
|
|
407
|
+
hidden_size=1024, # 4096
|
|
408
|
+
initializer_range=0.02,
|
|
409
|
+
intermediate_size=2048, # 14336
|
|
410
|
+
max_position_embeddings=8192,
|
|
411
|
+
num_attention_heads=8, # 32
|
|
412
|
+
num_hidden_layers=4, # 32
|
|
413
|
+
num_key_value_heads=2, # 8
|
|
414
|
+
pretraining_tp=1,
|
|
415
|
+
rms_norm_eps=1e-5,
|
|
416
|
+
rope_scaling=None,
|
|
417
|
+
rope_theta=500000.0,
|
|
418
|
+
tie_word_embeddings=False,
|
|
419
|
+
use_cache=True,
|
|
420
|
+
vocab_size=32000, # 128256,
|
|
421
|
+
# At rope backward
|
|
422
|
+
# Eager produces incontiguous dq and dk
|
|
423
|
+
# SDPA produces contiguous dq and incontiguous dk
|
|
424
|
+
# Flash_attn produces contiguous dq and dk
|
|
425
|
+
attn_implementation="sdpa", # default value, pytorch native attention
|
|
426
|
+
),
|
|
427
|
+
)
|
|
428
|
+
|
|
421
429
|
|
|
422
430
|
def create_model(model_name="mini_llama3"):
|
|
423
431
|
"""
|
|
@@ -462,7 +470,8 @@ def run_mini_model(
|
|
|
462
470
|
else:
|
|
463
471
|
kwargs["swiglu"] = True
|
|
464
472
|
|
|
465
|
-
|
|
473
|
+
# fused_linear_cross_entropy is not supported in mini_granite3
|
|
474
|
+
kwargs["fused_linear_cross_entropy"] = True if model_name != "mini_granite3" else False
|
|
466
475
|
kwargs["cross_entropy"] = False
|
|
467
476
|
|
|
468
477
|
MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
|
|
@@ -518,7 +527,13 @@ def run_mini_model(
|
|
|
518
527
|
1e-2,
|
|
519
528
|
1e-2,
|
|
520
529
|
1e-2,
|
|
521
|
-
marks=
|
|
530
|
+
marks=[
|
|
531
|
+
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
|
|
532
|
+
pytest.mark.skipif(
|
|
533
|
+
not GRANITE_AVAILABLE,
|
|
534
|
+
reason="Granite not available in this version of transformers",
|
|
535
|
+
),
|
|
536
|
+
],
|
|
522
537
|
),
|
|
523
538
|
pytest.param(
|
|
524
539
|
"mini_mllama",
|
|
@@ -7,8 +7,6 @@ from transformers.models.gemma import GemmaConfig
|
|
|
7
7
|
from transformers.models.gemma import GemmaForCausalLM
|
|
8
8
|
from transformers.models.gemma2 import Gemma2Config
|
|
9
9
|
from transformers.models.gemma2 import Gemma2ForCausalLM
|
|
10
|
-
from transformers.models.granite import GraniteConfig
|
|
11
|
-
from transformers.models.granite import GraniteForCausalLM
|
|
12
10
|
from transformers.models.llama import LlamaConfig
|
|
13
11
|
from transformers.models.llama import LlamaForCausalLM
|
|
14
12
|
from transformers.models.mistral import MistralConfig
|
|
@@ -65,6 +63,14 @@ try:
|
|
|
65
63
|
except ImportError:
|
|
66
64
|
QWEN2_VL_AVAILABLE = False
|
|
67
65
|
|
|
66
|
+
try:
|
|
67
|
+
from transformers.models.granite import GraniteConfig
|
|
68
|
+
from transformers.models.granite import GraniteForCausalLM
|
|
69
|
+
|
|
70
|
+
GRANITE_AVAILABLE = True
|
|
71
|
+
except ImportError:
|
|
72
|
+
GRANITE_AVAILABLE = False
|
|
73
|
+
|
|
68
74
|
from liger_kernel.utils import infer_device
|
|
69
75
|
|
|
70
76
|
device = infer_device()
|
|
@@ -103,40 +109,6 @@ MINI_MODEL_SETUPS = {
|
|
|
103
109
|
attn_implementation="sdpa", # default value, pytorch native attention
|
|
104
110
|
),
|
|
105
111
|
),
|
|
106
|
-
"mini_granite3": MiniModelConfig(
|
|
107
|
-
liger_kernel_patch_func=apply_liger_kernel_to_granite,
|
|
108
|
-
liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
|
|
109
|
-
model_class=GraniteForCausalLM,
|
|
110
|
-
mini_model_config=GraniteConfig(
|
|
111
|
-
attention_bias=False,
|
|
112
|
-
attention_dropout=0.0,
|
|
113
|
-
# Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
|
|
114
|
-
# https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
|
|
115
|
-
bos_token_id=1, # 128000
|
|
116
|
-
eos_token_id=2, # 128001
|
|
117
|
-
hidden_act="silu",
|
|
118
|
-
hidden_size=1024, # 4096
|
|
119
|
-
initializer_range=0.02,
|
|
120
|
-
intermediate_size=2048, # 14336
|
|
121
|
-
max_position_embeddings=8192,
|
|
122
|
-
num_attention_heads=8, # 32
|
|
123
|
-
num_hidden_layers=4, # 32
|
|
124
|
-
num_key_value_heads=2, # 8
|
|
125
|
-
pretraining_tp=1,
|
|
126
|
-
rms_norm_eps=1e-5,
|
|
127
|
-
rope_scaling=None,
|
|
128
|
-
rope_theta=500000.0,
|
|
129
|
-
tie_word_embeddings=False,
|
|
130
|
-
use_cache=True,
|
|
131
|
-
vocab_size=32000, # 128256,
|
|
132
|
-
logits_scaling=8.0,
|
|
133
|
-
# At rope backward
|
|
134
|
-
# Eager produces incontiguous dq and dk
|
|
135
|
-
# SDPA produces contiguous dq and incontiguous dk
|
|
136
|
-
# Flash_attn produces contiguous dq and dk
|
|
137
|
-
attn_implementation="sdpa", # default value, pytorch native attention
|
|
138
|
-
),
|
|
139
|
-
),
|
|
140
112
|
"mini_qwen2": MiniModelConfig(
|
|
141
113
|
liger_kernel_patch_func=apply_liger_kernel_to_qwen2,
|
|
142
114
|
liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2,
|
|
@@ -419,6 +391,42 @@ if QWEN2_VL_AVAILABLE:
|
|
|
419
391
|
),
|
|
420
392
|
)
|
|
421
393
|
|
|
394
|
+
if GRANITE_AVAILABLE:
|
|
395
|
+
MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig(
|
|
396
|
+
liger_kernel_patch_func=apply_liger_kernel_to_granite,
|
|
397
|
+
liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
|
|
398
|
+
model_class=GraniteForCausalLM,
|
|
399
|
+
mini_model_config=GraniteConfig(
|
|
400
|
+
attention_bias=False,
|
|
401
|
+
attention_dropout=0.0,
|
|
402
|
+
# Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
|
|
403
|
+
# https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
|
|
404
|
+
bos_token_id=1, # 128000
|
|
405
|
+
eos_token_id=2, # 128001
|
|
406
|
+
hidden_act="silu",
|
|
407
|
+
hidden_size=1024, # 4096
|
|
408
|
+
initializer_range=0.02,
|
|
409
|
+
intermediate_size=2048, # 14336
|
|
410
|
+
max_position_embeddings=8192,
|
|
411
|
+
num_attention_heads=8, # 32
|
|
412
|
+
num_hidden_layers=4, # 32
|
|
413
|
+
num_key_value_heads=2, # 8
|
|
414
|
+
pretraining_tp=1,
|
|
415
|
+
rms_norm_eps=1e-5,
|
|
416
|
+
rope_scaling=None,
|
|
417
|
+
rope_theta=500000.0,
|
|
418
|
+
tie_word_embeddings=False,
|
|
419
|
+
use_cache=True,
|
|
420
|
+
vocab_size=32000, # 128256,
|
|
421
|
+
logits_scaling=8.0,
|
|
422
|
+
# At rope backward
|
|
423
|
+
# Eager produces incontiguous dq and dk
|
|
424
|
+
# SDPA produces contiguous dq and incontiguous dk
|
|
425
|
+
# Flash_attn produces contiguous dq and dk
|
|
426
|
+
attn_implementation="sdpa", # default value, pytorch native attention
|
|
427
|
+
),
|
|
428
|
+
)
|
|
429
|
+
|
|
422
430
|
|
|
423
431
|
def create_model(model_name="mini_llama3"):
|
|
424
432
|
"""
|
|
@@ -518,7 +526,13 @@ def run_mini_model(
|
|
|
518
526
|
1e-2, # logits rtol
|
|
519
527
|
1e-2,
|
|
520
528
|
1e-2,
|
|
521
|
-
marks=
|
|
529
|
+
marks=[
|
|
530
|
+
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
|
|
531
|
+
pytest.mark.skipif(
|
|
532
|
+
not GRANITE_AVAILABLE,
|
|
533
|
+
reason="Granite not available in this version of transformers",
|
|
534
|
+
),
|
|
535
|
+
],
|
|
522
536
|
),
|
|
523
537
|
pytest.param(
|
|
524
538
|
"mini_mllama",
|