liger-kernel-nightly 0.5.2.dev20241223032015__tar.gz → 0.5.2.dev20241223042135__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/Makefile +3 -4
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/benchmarks_visualizer.py +5 -10
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_cpo_loss.py +11 -16
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_cross_entropy.py +10 -13
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_dpo_loss.py +16 -30
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_embedding.py +10 -13
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +16 -29
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_fused_linear_jsd.py +30 -43
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_geglu.py +7 -8
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_group_norm.py +15 -28
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_jsd.py +13 -20
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_kl_div.py +10 -17
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_layer_norm.py +11 -16
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_orpo_loss.py +11 -16
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_qwen2vl_mrope.py +23 -48
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_rms_norm.py +9 -10
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_rope.py +21 -42
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_simpo_loss.py +11 -16
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/benchmark_swiglu.py +8 -11
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/utils.py +16 -25
- liger_kernel_nightly-0.5.2.dev20241223042135/dev/fmt-requirements.txt +1 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/dev/modal/tests.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/dev/modal/tests_bwd.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/alignment/run_orpo.py +4 -4
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/callback.py +16 -34
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/launch_on_modal.py +2 -5
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/training.py +5 -7
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/training_multimodal.py +5 -7
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/lightning/training.py +17 -31
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/callback.py +19 -43
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/medusa_util.py +11 -33
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/train.py +19 -39
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/pyproject.toml +34 -2
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/setup.py +1 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/chunked_loss/cpo_loss.py +5 -11
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/chunked_loss/dpo_loss.py +1 -4
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/chunked_loss/orpo_loss.py +2 -6
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/chunked_loss/simpo_loss.py +4 -8
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/env_report.py +4 -11
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/cross_entropy.py +7 -10
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/experimental/embedding.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/fused_linear_jsd.py +11 -29
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/geglu.py +6 -17
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/group_norm.py +11 -28
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/jsd.py +2 -6
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/kl_div.py +4 -7
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/layer_norm.py +3 -5
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/qwen2vl_mrope.py +8 -25
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/rms_norm.py +11 -29
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/rope.py +31 -33
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/swiglu.py +4 -8
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/utils.py +2 -0
- liger_kernel_nightly-0.5.2.dev20241223042135/src/liger_kernel/transformers/__init__.py +23 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/auto_model.py +6 -13
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/cross_entropy.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/experimental/embedding.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/functional.py +2 -6
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/geglu.py +1 -4
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/group_norm.py +3 -9
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/jsd.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/kl_div.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/layer_norm.py +3 -9
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/model/gemma.py +18 -40
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/model/gemma2.py +19 -41
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/model/llama.py +22 -48
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/model/mistral.py +14 -26
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/model/mixtral.py +23 -53
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/model/mllama.py +16 -36
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/model/phi3.py +18 -40
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/model/qwen2.py +18 -40
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/model/qwen2_vl.py +16 -30
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/monkey_patch.py +43 -117
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/rms_norm.py +4 -4
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/rope.py +2 -2
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel_nightly-0.5.2.dev20241223042135/src/liger_kernel/transformers/trainer/__init__.py +4 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel_nightly-0.5.2.dev20241223042135/src/liger_kernel/triton/__init__.py +1 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -2
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/chunked_loss/test_cpo_loss.py +10 -22
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/chunked_loss/test_dpo_loss.py +15 -30
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/chunked_loss/test_orpo_loss.py +12 -23
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/chunked_loss/test_simpo_loss.py +11 -19
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/convergence/test_mini_models.py +54 -76
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/convergence/test_mini_models_multimodal.py +35 -73
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/convergence/test_mini_models_with_logits.py +54 -76
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/resources/scripts/generate_tokenized_dataset.py +4 -12
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_auto_model.py +14 -17
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_cross_entropy.py +32 -90
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_embedding.py +6 -19
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_fused_linear_cross_entropy.py +12 -26
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_fused_linear_jsd.py +41 -66
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_geglu.py +3 -5
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_group_norm.py +5 -17
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_jsd.py +25 -46
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_kl_div.py +5 -11
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_layer_norm.py +2 -6
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_mm_int8int2.py +8 -20
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_monkey_patch.py +153 -331
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_qwen2vl_mrope.py +15 -43
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_rms_norm.py +9 -18
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_rope.py +39 -27
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_swiglu.py +8 -15
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/transformers/test_trainer_integration.py +1 -3
- liger_kernel_nightly-0.5.2.dev20241223042135/test/transformers/test_transformers.py +16 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/utils.py +28 -55
- liger_kernel_nightly-0.5.2.dev20241223032015/.flake8 +0 -10
- liger_kernel_nightly-0.5.2.dev20241223032015/.isort.cfg +0 -2
- liger_kernel_nightly-0.5.2.dev20241223032015/dev/fmt-requirements.txt +0 -3
- liger_kernel_nightly-0.5.2.dev20241223032015/src/liger_kernel/transformers/__init__.py +0 -31
- liger_kernel_nightly-0.5.2.dev20241223032015/src/liger_kernel/transformers/trainer/__init__.py +0 -6
- liger_kernel_nightly-0.5.2.dev20241223032015/src/liger_kernel/triton/__init__.py +0 -3
- liger_kernel_nightly-0.5.2.dev20241223032015/test/transformers/test_transformers.py +0 -18
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/docs/Acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/docs/CONTRIBUTING.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/docs/License.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015 → liger_kernel_nightly-0.5.2.dev20241223042135}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -10,10 +10,9 @@ test:
|
|
10
10
|
# Command to run flake8 (code style check), isort (import ordering), and black (code formatting)
|
11
11
|
# Subsequent commands still run if the previous fails, but return failure at the end
|
12
12
|
checkstyle:
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
if [ $$flake8_status -ne 0 ] || [ $$isort_status -ne 0 ] || [ $$black_status -ne 0 ]; then \
|
13
|
+
ruff check . --fix; ruff_check_status=$$?; \
|
14
|
+
ruff format .; ruff_format_status=$$?; \
|
15
|
+
if [ $$ruff_check_status -ne 0 ] || [ $$ruff_format_status -ne 0 ]; then \
|
17
16
|
exit 1; \
|
18
17
|
fi
|
19
18
|
|
@@ -1,5 +1,6 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
|
+
|
3
4
|
from argparse import ArgumentParser
|
4
5
|
from dataclasses import dataclass
|
5
6
|
|
@@ -39,9 +40,7 @@ def parse_args() -> VisualizationsConfig:
|
|
39
40
|
VisualizationsConfig: Configuration object for the visualizations script.
|
40
41
|
"""
|
41
42
|
parser = ArgumentParser()
|
42
|
-
parser.add_argument(
|
43
|
-
"--kernel-name", type=str, required=True, help="Kernel name to benchmark"
|
44
|
-
)
|
43
|
+
parser.add_argument("--kernel-name", type=str, required=True, help="Kernel name to benchmark")
|
45
44
|
parser.add_argument(
|
46
45
|
"--metric-name",
|
47
46
|
type=str,
|
@@ -54,9 +53,7 @@ def parse_args() -> VisualizationsConfig:
|
|
54
53
|
required=True,
|
55
54
|
help="Kernel operation mode to visualize (forward/backward/full)",
|
56
55
|
)
|
57
|
-
parser.add_argument(
|
58
|
-
"--display", action="store_true", help="Display the visualization"
|
59
|
-
)
|
56
|
+
parser.add_argument("--display", action="store_true", help="Display the visualization")
|
60
57
|
parser.add_argument(
|
61
58
|
"--overwrite",
|
62
59
|
action="store_true",
|
@@ -126,7 +123,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
|
|
126
123
|
lines = ax.get_lines()
|
127
124
|
colors = [line.get_color() for line in lines]
|
128
125
|
|
129
|
-
for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
|
126
|
+
for (_, group_data), color in zip(df.groupby("kernel_provider"), colors, strict=False):
|
130
127
|
# for i, row in group_data.iterrows():
|
131
128
|
y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
|
132
129
|
y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
|
@@ -145,9 +142,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
|
|
145
142
|
plt.ylabel(ylabel)
|
146
143
|
plt.tight_layout()
|
147
144
|
|
148
|
-
out_path = os.path.join(
|
149
|
-
VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png"
|
150
|
-
)
|
145
|
+
out_path = os.path.join(VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png")
|
151
146
|
|
152
147
|
if config.display:
|
153
148
|
plt.show()
|
@@ -3,14 +3,13 @@ import sys
|
|
3
3
|
|
4
4
|
import torch
|
5
5
|
import triton
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
)
|
6
|
+
|
7
|
+
from utils import QUANTILES
|
8
|
+
from utils import SingleBenchmarkRunInput
|
9
|
+
from utils import SingleBenchmarkRunOutput
|
10
|
+
from utils import _test_memory
|
11
|
+
from utils import parse_benchmark_script_args
|
12
|
+
from utils import run_benchmarks
|
14
13
|
|
15
14
|
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
|
16
15
|
from liger_kernel.utils import infer_device
|
@@ -33,9 +32,7 @@ class TorchLMHeadCPO(torch.nn.Module):
|
|
33
32
|
from test.chunked_loss.test_cpo_loss import HFCPOLoss
|
34
33
|
|
35
34
|
super().__init__()
|
36
|
-
self.lin = torch.nn.Linear(
|
37
|
-
in_features=H, out_features=V, bias=False, dtype=dtype
|
38
|
-
)
|
35
|
+
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
|
39
36
|
self.cpo_loss = HFCPOLoss().get_batch_loss_metrics
|
40
37
|
|
41
38
|
def forward(self, x, y):
|
@@ -45,9 +42,7 @@ class TorchLMHeadCPO(torch.nn.Module):
|
|
45
42
|
class LigerLMHeadCPO(torch.nn.Module):
|
46
43
|
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
|
47
44
|
super().__init__()
|
48
|
-
self.lin = torch.nn.Linear(
|
49
|
-
in_features=H, out_features=V, bias=False, dtype=dtype
|
50
|
-
)
|
45
|
+
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
|
51
46
|
self.cpo_loss = LigerFusedLinearCPOFunction.apply
|
52
47
|
|
53
48
|
def forward(self, x, y):
|
@@ -180,12 +175,12 @@ if __name__ == "__main__":
|
|
180
175
|
kernel_operation_modes=["forward", "full"],
|
181
176
|
metric_name="speed",
|
182
177
|
metric_unit="ms",
|
183
|
-
**common_configs
|
178
|
+
**common_configs,
|
184
179
|
)
|
185
180
|
run_benchmarks(
|
186
181
|
bench_test_fn=bench_memory_fused_linear_cpo_loss,
|
187
182
|
kernel_operation_modes=["full"],
|
188
183
|
metric_name="memory",
|
189
184
|
metric_unit="MB",
|
190
|
-
**common_configs
|
185
|
+
**common_configs,
|
191
186
|
)
|
@@ -1,14 +1,13 @@
|
|
1
1
|
import torch
|
2
2
|
import triton
|
3
|
+
|
3
4
|
from torch.nn import CrossEntropyLoss
|
4
|
-
from utils import
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
run_benchmarks,
|
11
|
-
)
|
5
|
+
from utils import QUANTILES
|
6
|
+
from utils import SingleBenchmarkRunInput
|
7
|
+
from utils import SingleBenchmarkRunOutput
|
8
|
+
from utils import _test_memory
|
9
|
+
from utils import parse_benchmark_script_args
|
10
|
+
from utils import run_benchmarks
|
12
11
|
|
13
12
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
14
13
|
from liger_kernel.utils import infer_device
|
@@ -86,9 +85,7 @@ def bench_speed_cross_entropy(
|
|
86
85
|
y = fwd()
|
87
86
|
y.backward()
|
88
87
|
|
89
|
-
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
90
|
-
full, rep=100, quantiles=QUANTILES
|
91
|
-
)
|
88
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
|
92
89
|
|
93
90
|
return SingleBenchmarkRunOutput(
|
94
91
|
y_20=ms_20,
|
@@ -115,12 +112,12 @@ if __name__ == "__main__":
|
|
115
112
|
kernel_operation_modes=["forward", "full"],
|
116
113
|
metric_name="speed",
|
117
114
|
metric_unit="ms",
|
118
|
-
**common_configs
|
115
|
+
**common_configs,
|
119
116
|
)
|
120
117
|
run_benchmarks(
|
121
118
|
bench_test_fn=bench_memory_cross_entropy,
|
122
119
|
kernel_operation_modes=["full"],
|
123
120
|
metric_name="memory",
|
124
121
|
metric_unit="MB",
|
125
|
-
**common_configs
|
122
|
+
**common_configs,
|
126
123
|
)
|
@@ -1,15 +1,13 @@
|
|
1
|
-
from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
|
2
|
-
|
3
1
|
import torch
|
4
2
|
import triton
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
3
|
+
|
4
|
+
from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
|
5
|
+
from utils import QUANTILES
|
6
|
+
from utils import SingleBenchmarkRunInput
|
7
|
+
from utils import SingleBenchmarkRunOutput
|
8
|
+
from utils import _test_memory
|
9
|
+
from utils import parse_benchmark_script_args
|
10
|
+
from utils import run_benchmarks
|
13
11
|
|
14
12
|
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
|
15
13
|
from liger_kernel.utils import infer_device
|
@@ -28,9 +26,7 @@ class TorchDPOLoss(torch.nn.Module):
|
|
28
26
|
bias: bool = False,
|
29
27
|
):
|
30
28
|
super().__init__()
|
31
|
-
self.lin = torch.nn.Linear(
|
32
|
-
in_features=H, out_features=V, bias=bias, dtype=dtype
|
33
|
-
)
|
29
|
+
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
|
34
30
|
self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index)
|
35
31
|
|
36
32
|
def forward(self, x, target):
|
@@ -53,9 +49,7 @@ class LigerDPOLoss(torch.nn.Module):
|
|
53
49
|
bias: bool = False,
|
54
50
|
):
|
55
51
|
super().__init__()
|
56
|
-
self.lin = torch.nn.Linear(
|
57
|
-
in_features=H, out_features=V, bias=bias, dtype=dtype
|
58
|
-
)
|
52
|
+
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
|
59
53
|
self.beta = beta
|
60
54
|
self.ignore_index = ignore_index
|
61
55
|
|
@@ -82,12 +76,8 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
|
|
82
76
|
ignore_index = input.extra_benchmark_config["ignore_index"]
|
83
77
|
provider = input.kernel_provider
|
84
78
|
|
85
|
-
torch_dpo_loss = TorchDPOLoss(
|
86
|
-
|
87
|
-
).to(device)
|
88
|
-
liger_dpo_loss = LigerDPOLoss(
|
89
|
-
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
|
90
|
-
).to(device)
|
79
|
+
torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
|
80
|
+
liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
|
91
81
|
|
92
82
|
# Input shape: [B, T, H]
|
93
83
|
_input = torch.randn(B, T, H, device=device, dtype=dtype)
|
@@ -129,12 +119,8 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
|
|
129
119
|
provider = input.kernel_provider
|
130
120
|
mode = input.kernel_operation_mode
|
131
121
|
|
132
|
-
torch_dpo_loss = TorchDPOLoss(
|
133
|
-
|
134
|
-
).to(device)
|
135
|
-
liger_dpo_loss = LigerDPOLoss(
|
136
|
-
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
|
137
|
-
).to(device)
|
122
|
+
torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
|
123
|
+
liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
|
138
124
|
|
139
125
|
# Input shape: [B, T, H]
|
140
126
|
_input = torch.randn(B, T, H, device=device, dtype=dtype)
|
@@ -215,7 +201,7 @@ if __name__ == "__main__":
|
|
215
201
|
kernel_operation_modes=["forward", "full"],
|
216
202
|
metric_name="speed",
|
217
203
|
metric_unit="ms",
|
218
|
-
**common_configs
|
204
|
+
**common_configs,
|
219
205
|
)
|
220
206
|
|
221
207
|
run_benchmarks(
|
@@ -223,5 +209,5 @@ if __name__ == "__main__":
|
|
223
209
|
kernel_operation_modes=["full"],
|
224
210
|
metric_name="memory",
|
225
211
|
metric_unit="MB",
|
226
|
-
**common_configs
|
212
|
+
**common_configs,
|
227
213
|
)
|
@@ -1,14 +1,13 @@
|
|
1
1
|
import torch
|
2
2
|
import triton
|
3
|
+
|
3
4
|
from torch.nn import Embedding
|
4
|
-
from utils import
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
run_benchmarks,
|
11
|
-
)
|
5
|
+
from utils import QUANTILES
|
6
|
+
from utils import SingleBenchmarkRunInput
|
7
|
+
from utils import SingleBenchmarkRunOutput
|
8
|
+
from utils import _test_memory
|
9
|
+
from utils import parse_benchmark_script_args
|
10
|
+
from utils import run_benchmarks
|
12
11
|
|
13
12
|
from liger_kernel.transformers.experimental.embedding import LigerEmbedding
|
14
13
|
from liger_kernel.utils import infer_device
|
@@ -50,9 +49,7 @@ def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
|
|
50
49
|
if mode == "forward":
|
51
50
|
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
|
52
51
|
elif mode == "full":
|
53
|
-
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
54
|
-
full, quantiles=QUANTILES, rep=100
|
55
|
-
)
|
52
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
|
56
53
|
return SingleBenchmarkRunOutput(
|
57
54
|
y_20=ms_20,
|
58
55
|
y_50=ms_50,
|
@@ -118,12 +115,12 @@ if __name__ == "__main__":
|
|
118
115
|
kernel_operation_modes=["forward", "full"],
|
119
116
|
metric_name="speed",
|
120
117
|
metric_unit="ms",
|
121
|
-
**common_configs
|
118
|
+
**common_configs,
|
122
119
|
)
|
123
120
|
run_benchmarks(
|
124
121
|
bench_test_fn=bench_memory_embedding,
|
125
122
|
kernel_operation_modes=["full"],
|
126
123
|
metric_name="memory",
|
127
124
|
metric_unit="MB",
|
128
|
-
**common_configs
|
125
|
+
**common_configs,
|
129
126
|
)
|
@@ -1,17 +1,14 @@
|
|
1
1
|
import torch
|
2
2
|
import triton
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
from liger_kernel.transformers.fused_linear_cross_entropy import (
|
13
|
-
LigerFusedLinearCrossEntropyLoss,
|
14
|
-
)
|
3
|
+
|
4
|
+
from utils import QUANTILES
|
5
|
+
from utils import SingleBenchmarkRunInput
|
6
|
+
from utils import SingleBenchmarkRunOutput
|
7
|
+
from utils import _test_memory
|
8
|
+
from utils import parse_benchmark_script_args
|
9
|
+
from utils import run_benchmarks
|
10
|
+
|
11
|
+
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
15
12
|
from liger_kernel.utils import infer_device
|
16
13
|
|
17
14
|
device = infer_device()
|
@@ -28,12 +25,8 @@ class TorchLMHeadCE(torch.nn.Module):
|
|
28
25
|
|
29
26
|
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
|
30
27
|
super().__init__()
|
31
|
-
self.lin = torch.nn.Linear(
|
32
|
-
|
33
|
-
)
|
34
|
-
self.ce_loss = torch.nn.CrossEntropyLoss(
|
35
|
-
ignore_index=ignore_index, reduction="mean"
|
36
|
-
)
|
28
|
+
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
|
29
|
+
self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction="mean")
|
37
30
|
|
38
31
|
def forward(self, x, y):
|
39
32
|
logits = self.lin(x)
|
@@ -43,12 +36,8 @@ class TorchLMHeadCE(torch.nn.Module):
|
|
43
36
|
class LigerLMHeadCE(torch.nn.Module):
|
44
37
|
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
|
45
38
|
super().__init__()
|
46
|
-
self.lin = torch.nn.Linear(
|
47
|
-
|
48
|
-
)
|
49
|
-
self.ce_loss = LigerFusedLinearCrossEntropyLoss(
|
50
|
-
ignore_index=ignore_index, reduction="mean"
|
51
|
-
)
|
39
|
+
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
|
40
|
+
self.ce_loss = LigerFusedLinearCrossEntropyLoss(ignore_index=ignore_index, reduction="mean")
|
52
41
|
|
53
42
|
def forward(self, x, y):
|
54
43
|
return self.ce_loss(self.lin.weight, x, y)
|
@@ -161,9 +150,7 @@ if __name__ == "__main__":
|
|
161
150
|
"x_label": "B x T",
|
162
151
|
"x_values": [2**i for i in range(12, 16)],
|
163
152
|
"kernel_providers": ["liger", "huggingface"],
|
164
|
-
"extra_benchmark_configs": [
|
165
|
-
{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}
|
166
|
-
],
|
153
|
+
"extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}],
|
167
154
|
"overwrite": args.overwrite,
|
168
155
|
}
|
169
156
|
|
@@ -172,12 +159,12 @@ if __name__ == "__main__":
|
|
172
159
|
kernel_operation_modes=["forward", "full"],
|
173
160
|
metric_name="speed",
|
174
161
|
metric_unit="ms",
|
175
|
-
**common_configs
|
162
|
+
**common_configs,
|
176
163
|
)
|
177
164
|
run_benchmarks(
|
178
165
|
bench_test_fn=bench_memory_fused_linear_cross_entropy,
|
179
166
|
kernel_operation_modes=["full"],
|
180
167
|
metric_name="memory",
|
181
168
|
metric_unit="MB",
|
182
|
-
**common_configs
|
169
|
+
**common_configs,
|
183
170
|
)
|
@@ -1,13 +1,12 @@
|
|
1
1
|
import torch
|
2
2
|
import triton
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
)
|
3
|
+
|
4
|
+
from utils import QUANTILES
|
5
|
+
from utils import SingleBenchmarkRunInput
|
6
|
+
from utils import SingleBenchmarkRunOutput
|
7
|
+
from utils import _test_memory
|
8
|
+
from utils import parse_benchmark_script_args
|
9
|
+
from utils import run_benchmarks
|
11
10
|
|
12
11
|
from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD
|
13
12
|
from liger_kernel.utils import infer_device
|
@@ -37,9 +36,9 @@ class TorchJSD(torch.nn.Module):
|
|
37
36
|
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
|
38
37
|
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
|
39
38
|
m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
|
40
|
-
loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
|
41
|
-
|
42
|
-
)
|
39
|
+
loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl(
|
40
|
+
torch.log(m), log_q
|
41
|
+
).sum(dim=-1)
|
43
42
|
|
44
43
|
if label is not None:
|
45
44
|
loss = torch.where(label != self.ignore_index, loss, 0.0)
|
@@ -73,12 +72,8 @@ class TorchLMHeadJSD(torch.nn.Module):
|
|
73
72
|
temperature: float = 1.0,
|
74
73
|
):
|
75
74
|
super().__init__()
|
76
|
-
self.student_lin = torch.nn.Linear(
|
77
|
-
|
78
|
-
)
|
79
|
-
self.teacher_lin = torch.nn.Linear(
|
80
|
-
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
|
81
|
-
)
|
75
|
+
self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
|
76
|
+
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
|
82
77
|
self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype)
|
83
78
|
self.temperature = temperature
|
84
79
|
|
@@ -103,15 +98,9 @@ class LigerLMHeadJSD(torch.nn.Module):
|
|
103
98
|
temperature: float = 1.0,
|
104
99
|
):
|
105
100
|
super().__init__()
|
106
|
-
self.student_lin = torch.nn.Linear(
|
107
|
-
|
108
|
-
)
|
109
|
-
self.teacher_lin = torch.nn.Linear(
|
110
|
-
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
|
111
|
-
)
|
112
|
-
self.fused_jsd = LigerFusedLinearJSD(
|
113
|
-
jsd_beta=beta, ignore_index=ignore_index, temperature=temperature
|
114
|
-
)
|
101
|
+
self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
|
102
|
+
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
|
103
|
+
self.fused_jsd = LigerFusedLinearJSD(jsd_beta=beta, ignore_index=ignore_index, temperature=temperature)
|
115
104
|
|
116
105
|
def forward(self, student_input, teacher_input, label=None):
|
117
106
|
return self.fused_jsd(
|
@@ -141,12 +130,12 @@ def bench_memory_fused_linear_jsd(
|
|
141
130
|
liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
|
142
131
|
|
143
132
|
# init the linear in all FusedLinearJSDs with the same weights
|
144
|
-
torch_lm_head_jsd.student_lin.weight.data = (
|
145
|
-
|
146
|
-
)
|
147
|
-
torch_lm_head_jsd.teacher_lin.weight.data = (
|
148
|
-
|
149
|
-
)
|
133
|
+
torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
|
134
|
+
V, H, device=device, dtype=dtype
|
135
|
+
)
|
136
|
+
torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand(
|
137
|
+
V, H, device=device, dtype=dtype
|
138
|
+
)
|
150
139
|
|
151
140
|
student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device)
|
152
141
|
teacher_input = torch.rand(BT, H, dtype=dtype, device=device)
|
@@ -189,12 +178,12 @@ def bench_speed_fused_linear_jsd(
|
|
189
178
|
liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
|
190
179
|
|
191
180
|
# init the linear in all FusedLinearJSDs with the same weights
|
192
|
-
torch_lm_head_jsd.student_lin.weight.data = (
|
193
|
-
|
194
|
-
)
|
195
|
-
torch_lm_head_jsd.teacher_lin.weight.data = (
|
196
|
-
|
197
|
-
)
|
181
|
+
torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
|
182
|
+
V, H, device=device, dtype=dtype
|
183
|
+
)
|
184
|
+
torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand(
|
185
|
+
V, H, device=device, dtype=dtype
|
186
|
+
)
|
198
187
|
|
199
188
|
student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device)
|
200
189
|
teacher_input = torch.rand(BT, H, dtype=dtype, device=device)
|
@@ -251,9 +240,7 @@ if __name__ == "__main__":
|
|
251
240
|
"x_label": "B x T",
|
252
241
|
"x_values": [2**i for i in range(10, 14)],
|
253
242
|
"kernel_providers": ["liger", "torch"],
|
254
|
-
"extra_benchmark_configs": [
|
255
|
-
{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}
|
256
|
-
],
|
243
|
+
"extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}],
|
257
244
|
"overwrite": args.overwrite,
|
258
245
|
}
|
259
246
|
|
@@ -262,12 +249,12 @@ if __name__ == "__main__":
|
|
262
249
|
kernel_operation_modes=["forward", "full"],
|
263
250
|
metric_name="speed",
|
264
251
|
metric_unit="ms",
|
265
|
-
**common_configs
|
252
|
+
**common_configs,
|
266
253
|
)
|
267
254
|
run_benchmarks(
|
268
255
|
bench_test_fn=bench_memory_fused_linear_jsd,
|
269
256
|
kernel_operation_modes=["full"],
|
270
257
|
metric_name="memory",
|
271
258
|
metric_unit="MB",
|
272
|
-
**common_configs
|
259
|
+
**common_configs,
|
273
260
|
)
|
@@ -1,15 +1,14 @@
|
|
1
1
|
import torch
|
2
2
|
import triton
|
3
|
+
|
3
4
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
4
5
|
from transformers.models.llama.modeling_llama import LlamaMLP
|
5
|
-
from utils import
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
run_benchmarks,
|
12
|
-
)
|
6
|
+
from utils import QUANTILES
|
7
|
+
from utils import SingleBenchmarkRunInput
|
8
|
+
from utils import SingleBenchmarkRunOutput
|
9
|
+
from utils import _test_memory
|
10
|
+
from utils import parse_benchmark_script_args
|
11
|
+
from utils import run_benchmarks
|
13
12
|
|
14
13
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
15
14
|
from liger_kernel.utils import infer_device
|
@@ -1,13 +1,12 @@
|
|
1
1
|
import torch
|
2
2
|
import triton
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
)
|
3
|
+
|
4
|
+
from utils import QUANTILES
|
5
|
+
from utils import SingleBenchmarkRunInput
|
6
|
+
from utils import SingleBenchmarkRunOutput
|
7
|
+
from utils import _test_memory
|
8
|
+
from utils import parse_benchmark_script_args
|
9
|
+
from utils import run_benchmarks
|
11
10
|
|
12
11
|
from liger_kernel.transformers.group_norm import LigerGroupNorm
|
13
12
|
from liger_kernel.utils import infer_device
|
@@ -27,12 +26,8 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
|
|
27
26
|
dtype = extra_benchmark_config["dtype"]
|
28
27
|
|
29
28
|
x_shape = (M, C, H)
|
30
|
-
triton_ln = LigerGroupNorm(
|
31
|
-
|
32
|
-
).to(device)
|
33
|
-
torch_ln = torch.nn.GroupNorm(
|
34
|
-
num_groups=C // channels_per_group, num_channels=C, eps=eps
|
35
|
-
).to(device)
|
29
|
+
triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device)
|
30
|
+
torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device)
|
36
31
|
|
37
32
|
x = torch.randn(x_shape, dtype=dtype, device=device)
|
38
33
|
dy = torch.randn_like(x)
|
@@ -45,9 +40,7 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
|
|
45
40
|
return torch_ln(x)
|
46
41
|
|
47
42
|
if mode == "forward":
|
48
|
-
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
49
|
-
y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
|
50
|
-
)
|
43
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
|
51
44
|
elif mode == "backward":
|
52
45
|
y = y_fwd()
|
53
46
|
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
@@ -62,9 +55,7 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
|
|
62
55
|
y = y_fwd()
|
63
56
|
y.backward(dy, retain_graph=True)
|
64
57
|
|
65
|
-
ms_50, ms_20, ms_80 = triton.testing.do_bench(
|
66
|
-
full, quantiles=QUANTILES, grad_to_none=[x], rep=500
|
67
|
-
)
|
58
|
+
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
|
68
59
|
|
69
60
|
return SingleBenchmarkRunOutput(
|
70
61
|
y_20=ms_20,
|
@@ -84,12 +75,8 @@ def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu
|
|
84
75
|
dtype = extra_benchmark_config["dtype"]
|
85
76
|
|
86
77
|
x_shape = (M, C, H)
|
87
|
-
triton_ln = LigerGroupNorm(
|
88
|
-
|
89
|
-
).to(device)
|
90
|
-
torch_ln = torch.nn.GroupNorm(
|
91
|
-
num_groups=C // channels_per_group, num_channels=C, eps=eps
|
92
|
-
).to(device)
|
78
|
+
triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device)
|
79
|
+
torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device)
|
93
80
|
|
94
81
|
x = torch.randn(x_shape, dtype=dtype, device=device)
|
95
82
|
dy = torch.randn_like(x)
|
@@ -139,12 +126,12 @@ if __name__ == "__main__":
|
|
139
126
|
kernel_operation_modes=["forward", "full", "backward"],
|
140
127
|
metric_name="speed",
|
141
128
|
metric_unit="ms",
|
142
|
-
**common_configs
|
129
|
+
**common_configs,
|
143
130
|
)
|
144
131
|
run_benchmarks(
|
145
132
|
bench_test_fn=bench_memory_group_norm,
|
146
133
|
kernel_operation_modes=["full", "forward", "backward"],
|
147
134
|
metric_name="memory",
|
148
135
|
metric_unit="MB",
|
149
|
-
**common_configs
|
136
|
+
**common_configs,
|
150
137
|
)
|