liger-kernel-nightly 0.5.4.dev20250312154502__tar.gz → 0.5.5.dev20250314203927__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +2 -2
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/jsd_loss.py +12 -7
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/jsd.py +30 -11
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/chunked_loss/test_jsd_loss.py +15 -10
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_jsd.py +3 -3
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/utils.py +3 -2
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/.gitignore +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/Makefile +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/index.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/docs/license.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/setup.cfg +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/setup.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/olmo2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/test/triton/test_triton_monkey_patch.py +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel_nightly"
|
|
7
|
-
version = "0.5.4.dev20250312154502"
|
|
7
|
+
version = "0.5.5.dev20250314203927"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -117,7 +117,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
117
117
|
|
|
118
118
|
hard_loss /= full_target.shape[0]
|
|
119
119
|
|
|
120
|
-
soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
|
|
120
|
+
soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
|
|
121
121
|
soft_loss /= full_target.shape[0]
|
|
122
122
|
|
|
123
123
|
loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
|
|
@@ -180,9 +180,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
180
180
|
ignore_index=ignore_index,
|
|
181
181
|
weight_hard_loss=weight_hard_loss,
|
|
182
182
|
weight_soft_loss=weight_soft_loss,
|
|
183
|
-
beta=beta,
|
|
184
183
|
compute_ce_loss=compute_ce_loss,
|
|
185
184
|
temperature=temperature,
|
|
185
|
+
beta=beta,
|
|
186
186
|
**loss_kwargs,
|
|
187
187
|
)
|
|
188
188
|
|
|
@@ -19,15 +19,20 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
|
19
19
|
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
|
20
20
|
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
22
|
+
if beta == 0:
|
|
23
|
+
jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
|
|
24
|
+
elif beta == 1:
|
|
25
|
+
jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
|
|
26
|
+
else:
|
|
27
|
+
# Compute probabilities (only required for mean calculation)
|
|
28
|
+
mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
|
|
29
|
+
log_mean_probs = mean_probs.log()
|
|
25
30
|
|
|
26
|
-
|
|
27
|
-
|
|
31
|
+
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
|
|
32
|
+
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
|
|
28
33
|
|
|
29
|
-
|
|
30
|
-
|
|
34
|
+
# JSD is the weighted average of the KL divergences
|
|
35
|
+
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
|
|
31
36
|
return jsd_loss
|
|
32
37
|
|
|
33
38
|
@classmethod
|
|
@@ -51,24 +51,43 @@ def _jsd_kernel(
|
|
|
51
51
|
Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
|
|
52
52
|
|
|
53
53
|
if beta == 0.0: # forward KL
|
|
54
|
-
|
|
54
|
+
Y_max = tl.max(Y, axis=0)
|
|
55
|
+
Y_shifted = Y - Y_max
|
|
56
|
+
Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift
|
|
55
57
|
loss = Y_prob * (Y - X)
|
|
56
58
|
dX = -Y_prob
|
|
57
|
-
elif beta == 1.0:
|
|
58
|
-
|
|
59
|
+
elif beta == 1.0: # reverse KL
|
|
60
|
+
X_max = tl.max(X, axis=0)
|
|
61
|
+
X_shifted = X - X_max
|
|
62
|
+
X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift
|
|
59
63
|
loss = X_prob * (X - Y)
|
|
60
64
|
dX = loss + X_prob
|
|
61
65
|
else:
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
log_M = tl.log(M)
|
|
66
|
+
max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
|
|
67
|
+
X_shifted = X - max_val
|
|
68
|
+
Y_shifted = Y - max_val
|
|
66
69
|
|
|
67
|
-
|
|
68
|
-
|
|
70
|
+
# Pre-compute exp(max_val) since it's used twice
|
|
71
|
+
exp_max = tl.exp(max_val)
|
|
72
|
+
|
|
73
|
+
# Compute exp terms with compensation
|
|
74
|
+
Q = tl.exp(X_shifted) * exp_max # = exp(X)
|
|
75
|
+
P = tl.exp(Y_shifted) * exp_max # = exp(Y)
|
|
76
|
+
|
|
77
|
+
# Pre-compute common terms
|
|
78
|
+
beta_P = beta * P
|
|
79
|
+
one_minus_beta_Q = (1 - beta) * Q
|
|
80
|
+
M = beta_P + one_minus_beta_Q
|
|
81
|
+
log_M = tl.log(M) # No need to compensate as M is already in original scale
|
|
82
|
+
|
|
83
|
+
loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
|
|
84
|
+
dX = one_minus_beta_Q * (X - log_M)
|
|
85
|
+
|
|
86
|
+
# Pre-compute scaling factor
|
|
87
|
+
scale = 1.0 / n_non_ignore
|
|
88
|
+
loss = loss * scale
|
|
89
|
+
dX = dX * scale
|
|
69
90
|
|
|
70
|
-
loss = loss / n_non_ignore
|
|
71
|
-
dX = dX / n_non_ignore
|
|
72
91
|
tl.store(loss_ptr + offsets, loss, mask=mask)
|
|
73
92
|
tl.store(dX_ptr + offsets, dX, mask=mask)
|
|
74
93
|
|
|
@@ -27,7 +27,6 @@ class HFJSDLoss(HFDistillationLoss):
|
|
|
27
27
|
ignore_index: int = -100,
|
|
28
28
|
weight_hard_loss: float = 0.5,
|
|
29
29
|
weight_soft_loss: float = 0.5,
|
|
30
|
-
beta: float = 0.5,
|
|
31
30
|
):
|
|
32
31
|
super().__init__(
|
|
33
32
|
ignore_index=ignore_index,
|
|
@@ -35,7 +34,6 @@ class HFJSDLoss(HFDistillationLoss):
|
|
|
35
34
|
weight_soft_loss=weight_soft_loss,
|
|
36
35
|
temperature=temperature,
|
|
37
36
|
)
|
|
38
|
-
self.beta = (beta,)
|
|
39
37
|
|
|
40
38
|
def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
|
|
41
39
|
"""
|
|
@@ -50,15 +48,20 @@ class HFJSDLoss(HFDistillationLoss):
|
|
|
50
48
|
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
|
51
49
|
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
|
52
50
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
51
|
+
if beta == 0:
|
|
52
|
+
jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
|
53
|
+
elif beta == 1:
|
|
54
|
+
jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
|
|
55
|
+
else:
|
|
56
|
+
# Compute probabilities (only required for mean calculation)
|
|
57
|
+
mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
|
|
58
|
+
log_mean_probs = mean_probs.log()
|
|
56
59
|
|
|
57
|
-
|
|
58
|
-
|
|
60
|
+
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
|
|
61
|
+
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)
|
|
59
62
|
|
|
60
|
-
|
|
61
|
-
|
|
63
|
+
# JSD is the weighted average of the KL divergences
|
|
64
|
+
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
|
|
62
65
|
return jsd_loss
|
|
63
66
|
|
|
64
67
|
|
|
@@ -88,12 +91,12 @@ class TorchLMHeadJSD(torch.nn.Module):
|
|
|
88
91
|
# smaller student model weights
|
|
89
92
|
self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device)
|
|
90
93
|
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device)
|
|
94
|
+
self.beta = beta
|
|
91
95
|
self.jsd = HFJSDLoss(
|
|
92
96
|
ignore_index=ignore_index,
|
|
93
97
|
weight_hard_loss=weight_hard_loss,
|
|
94
98
|
weight_soft_loss=weight_soft_loss,
|
|
95
99
|
temperature=temperature,
|
|
96
|
-
beta=beta,
|
|
97
100
|
).get_batch_loss_metrics
|
|
98
101
|
|
|
99
102
|
def forward(self, student_input, teacher_input, target):
|
|
@@ -105,6 +108,7 @@ class TorchLMHeadJSD(torch.nn.Module):
|
|
|
105
108
|
target,
|
|
106
109
|
self.student_lin.bias,
|
|
107
110
|
self.teacher_lin.bias,
|
|
111
|
+
beta=self.beta,
|
|
108
112
|
)
|
|
109
113
|
return jsd_loss
|
|
110
114
|
|
|
@@ -132,6 +136,7 @@ class LigerLMHeadJSD(torch.nn.Module):
|
|
|
132
136
|
weight_soft_loss=weight_soft_loss,
|
|
133
137
|
ignore_index=ignore_index,
|
|
134
138
|
temperature=temperature,
|
|
139
|
+
beta=beta,
|
|
135
140
|
)
|
|
136
141
|
|
|
137
142
|
def forward(self, student_input, teacher_input, target):
|
|
@@ -33,13 +33,13 @@ class JSD(torch.nn.Module):
|
|
|
33
33
|
|
|
34
34
|
def forward(
|
|
35
35
|
self,
|
|
36
|
-
log_q: torch.Tensor, # input
|
|
36
|
+
log_q: torch.Tensor, # input student logits
|
|
37
37
|
log_p: torch.Tensor, # target
|
|
38
38
|
label: Optional[torch.Tensor] = None,
|
|
39
39
|
):
|
|
40
|
-
if self.beta == 0.0:
|
|
40
|
+
if self.beta == 0.0: # KL(p||q) -> kl(q, p)
|
|
41
41
|
loss = self.kl(log_q, log_p).sum(dim=-1)
|
|
42
|
-
elif self.beta == 1.0:
|
|
42
|
+
elif self.beta == 1.0: # KL(q||p) -> kl(p, q)
|
|
43
43
|
loss = self.kl(log_p, log_q).sum(dim=-1)
|
|
44
44
|
else:
|
|
45
45
|
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
|
|
@@ -593,7 +593,7 @@ class HFDistillationLoss:
|
|
|
593
593
|
self.temperature = temperature
|
|
594
594
|
|
|
595
595
|
@abstractmethod
|
|
596
|
-
def distillation_loss(self, student_logits, teacher_logits):
|
|
596
|
+
def distillation_loss(self, student_logits, teacher_logits, **loss_kwargs):
|
|
597
597
|
"""Abstract method for computing distillation loss."""
|
|
598
598
|
pass
|
|
599
599
|
|
|
@@ -666,6 +666,7 @@ class HFDistillationLoss:
|
|
|
666
666
|
target: torch.LongTensor,
|
|
667
667
|
student_bias: torch.FloatTensor = None,
|
|
668
668
|
teacher_bias: torch.FloatTensor = None,
|
|
669
|
+
**loss_kwargs,
|
|
669
670
|
):
|
|
670
671
|
"""Compute the distillation loss metrics for the given batch."""
|
|
671
672
|
forward_output = self.concatenated_forward(
|
|
@@ -686,7 +687,7 @@ class HFDistillationLoss:
|
|
|
686
687
|
student_logits /= self.temperature
|
|
687
688
|
teacher_logits /= self.temperature
|
|
688
689
|
|
|
689
|
-
soft_loss = self.distillation_loss(student_logits, teacher_logits)
|
|
690
|
+
soft_loss = self.distillation_loss(student_logits, teacher_logits, **loss_kwargs)
|
|
690
691
|
# full loss
|
|
691
692
|
loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss
|
|
692
693
|
return loss
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{liger_kernel_nightly-0.5.4.dev20250312154502 → liger_kernel_nightly-0.5.5.dev20250314203927}/NOTICE
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|