liger-kernel-nightly 0.6.3.dev20251105224413__tar.gz → 0.6.3.dev20251106220336__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/cross_entropy.py +59 -9
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/fused_linear_cross_entropy.py +27 -4
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/cross_entropy.py +8 -3
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/functional.py +24 -6
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/falcon_h1.py +19 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/gemma.py +17 -6
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/gemma2.py +14 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/gemma3.py +25 -12
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/glm4.py +16 -4
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/glm4v.py +16 -4
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/glm4v_moe.py +23 -4
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/internvl.py +12 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/llama.py +14 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/llama4.py +16 -4
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/llava.py +12 -4
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/loss_utils.py +31 -3
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/mistral.py +15 -6
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/mixtral.py +16 -7
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/mllama.py +12 -4
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel_nightly-0.6.3.dev20251106220336/src/liger_kernel/transformers/model/output_classes.py +147 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/paligemma.py +22 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/phi3.py +14 -7
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/qwen2.py +16 -3
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/qwen2_vl.py +16 -4
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/qwen3.py +18 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/qwen3_moe.py +19 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/qwen3_next.py +17 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/qwen3_vl.py +11 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/qwen3_vl_moe.py +12 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/smollm3.py +15 -6
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_cross_entropy.py +81 -6
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_fused_linear_cross_entropy.py +229 -5
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.github/pull_request_template.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.github/workflows/amd-ci.yml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.github/workflows/benchmark.yml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.github/workflows/docs.yml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.github/workflows/intel-ci.yml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.github/workflows/nvi-ci.yml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.github/workflows/publish-nightly.yml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.github/workflows/publish-release.yml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/.gitignore +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/Makefile +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/README.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/README.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/benchmarks_visualizer.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/data/all_benchmark_data.csv +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_distill_cosine_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_dyt.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_embedding.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_geglu.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_group_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_jsd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_kl_div.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_kto_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_llama4_rope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_poly_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_rope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_softmax.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_sparse_multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_swiglu.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/benchmark_tvd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/benchmark/scripts/utils.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/dev/fmt-requirements.txt +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/dev/modal/benchmarks.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/dev/modal/tests.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/dev/modal/tests_bwd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/Examples.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/Getting-Started.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/High-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/Low-Level-APIs.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/acknowledgement.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/contributing.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/images/banner.GIF +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/images/compose.gif +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/images/e2e-memory.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/images/e2e-tps.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/images/logo-banner.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/images/patch.gif +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/images/post-training.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/index.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/docs/license.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/alignment/accelerate_config.yaml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/alignment/run_orpo.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/README.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/callback.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/config/fsdp_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/img/gemma_7b_mem.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/img/gemma_7b_tp.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/img/llama_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/img/llama_tps.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/img/qwen_tps.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/launch_on_modal.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/requirements.txt +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/run_benchmarks.sh +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/run_gemma.sh +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/run_llama.sh +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/run_qwen.sh +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/run_qwen2_vl.sh +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/training.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/huggingface/training_multimodal.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/lightning/README.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/lightning/requirements.txt +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/lightning/training.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/README.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/callback.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/medusa_util.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/requirements.txt +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/examples/medusa/train.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/licenses/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/licenses/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/licenses/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/licenses/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/mkdocs.yml +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/setup.cfg +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/setup.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/README.md +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/cosine_similarity_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/dyt.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/llama4_rope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/poly_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/softmax.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/tvd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/dyt.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/experimental/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/fsdp.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/llama4_rope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/model/smolvlm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/poly_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/softmax.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/transformers/tvd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/chunked_loss/test_cosine_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/chunked_loss/test_cpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/chunked_loss/test_dpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/chunked_loss/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/chunked_loss/test_jsd_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/chunked_loss/test_kto_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/chunked_loss/test_orpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/chunked_loss/test_simpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/conftest.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/convergence/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/convergence/bf16/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/convergence/bf16/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/convergence/fp32/__init__.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/convergence/fp32/test_mini_models.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/HuggingFaceTB/SmolVLM2-256M-Video-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/OpenGVLab/InternVL3-1B-hf/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/Qwen/Qwen3-VL-4B-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/fake_configs/meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/tiny_shakespeare.txt +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_auto_model.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_dyt.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_embedding.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_flex_attention.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_fused_add_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_fused_neighborhood_attention.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_geglu.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_group_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_grpo_loss.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_jsd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_kl_div.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_layer_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_multi_token_attention.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_poly_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_rms_norm.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_rope.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_softmax.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_sparsemax.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_swiglu.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_trainer_integration.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_transformers.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/transformers/test_tvd.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/triton/test_triton_monkey_patch.py +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105224413 → liger_kernel_nightly-0.6.3.dev20251106220336}/test/utils.py +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel_nightly"
|
|
7
|
-
version = "0.6.3.dev20251105224413"
|
|
7
|
+
version = "0.6.3.dev20251106220336"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -32,6 +32,8 @@ def liger_cross_entropy_kernel(
|
|
|
32
32
|
loss_ptr,
|
|
33
33
|
z_loss_ptr,
|
|
34
34
|
loss_stride,
|
|
35
|
+
token_accuracy_ptr,
|
|
36
|
+
token_accuracy_stride,
|
|
35
37
|
n_cols,
|
|
36
38
|
n_non_ignore,
|
|
37
39
|
sum_non_ignore_weight,
|
|
@@ -42,6 +44,7 @@ def liger_cross_entropy_kernel(
|
|
|
42
44
|
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
|
|
43
45
|
softcap,
|
|
44
46
|
RETURN_Z_LOSS: tl.constexpr,
|
|
47
|
+
RETURN_TOKEN_ACCURACY: tl.constexpr,
|
|
45
48
|
BLOCK_SIZE: tl.constexpr,
|
|
46
49
|
HAS_WEIGHT: tl.constexpr,
|
|
47
50
|
HAS_SOFTCAPPING: tl.constexpr,
|
|
@@ -60,6 +63,8 @@ def liger_cross_entropy_kernel(
|
|
|
60
63
|
loss_ptr: Pointer to tensor to store the loss.
|
|
61
64
|
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
|
|
62
65
|
loss_stride (int): The stride of the loss tensor.
|
|
66
|
+
token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
|
|
67
|
+
token_accuracy_stride (int): The stride of the token accuracy tensor.
|
|
63
68
|
n_cols (int): The number of columns in the input tensor.
|
|
64
69
|
n_non_ignore (float): The number of non-ignored elements in the batch.
|
|
65
70
|
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
|
|
@@ -69,7 +74,8 @@ def liger_cross_entropy_kernel(
|
|
|
69
74
|
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
|
|
70
75
|
reduction (str): The string for the reduction to apply
|
|
71
76
|
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
|
|
72
|
-
RETURN_Z_LOSS (int): The boolean value to decide whether
|
|
77
|
+
RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
|
|
78
|
+
RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
|
|
73
79
|
BLOCK_SIZE (int): The block size for Triton operations.
|
|
74
80
|
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
|
|
75
81
|
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
|
|
@@ -92,11 +98,17 @@ def liger_cross_entropy_kernel(
|
|
|
92
98
|
for i in range(0, n_cols, BLOCK_SIZE):
|
|
93
99
|
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
|
94
100
|
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
|
|
101
|
+
# For ignored tokens, set token accuracy to 0
|
|
102
|
+
if RETURN_TOKEN_ACCURACY:
|
|
103
|
+
token_accuracy_ptr += program_id * token_accuracy_stride
|
|
104
|
+
tl.store(token_accuracy_ptr, 0.0)
|
|
95
105
|
return
|
|
96
106
|
|
|
97
107
|
loss_ptr += program_id * loss_stride
|
|
98
108
|
if RETURN_Z_LOSS:
|
|
99
109
|
z_loss_ptr += program_id * loss_stride
|
|
110
|
+
if RETURN_TOKEN_ACCURACY:
|
|
111
|
+
token_accuracy_ptr += program_id * token_accuracy_stride
|
|
100
112
|
|
|
101
113
|
if HAS_WEIGHT:
|
|
102
114
|
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
|
|
@@ -107,6 +119,7 @@ def liger_cross_entropy_kernel(
|
|
|
107
119
|
# 3. [Online softmax] first pass: find max + sum
|
|
108
120
|
m = float("-inf") # m is the max value. use the notation from the paper
|
|
109
121
|
d = 0.0 # d is the sum. use the notation from the paper
|
|
122
|
+
argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
|
|
110
123
|
ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
|
|
111
124
|
if HAS_SOFTCAPPING:
|
|
112
125
|
ori_X_y = softcap * tanh(ori_X_y / softcap)
|
|
@@ -127,6 +140,16 @@ def liger_cross_entropy_kernel(
|
|
|
127
140
|
if HAS_SOFTCAPPING:
|
|
128
141
|
X_block = softcap * tanh(X_block / softcap)
|
|
129
142
|
block_max = tl.max(X_block)
|
|
143
|
+
|
|
144
|
+
# Track argmax for accuracy computation
|
|
145
|
+
if RETURN_TOKEN_ACCURACY and block_max > m:
|
|
146
|
+
# Find the index of the maximum value in this block
|
|
147
|
+
is_max_mask = X_block == block_max
|
|
148
|
+
# Mask out invalid indices with a value larger than n_cols
|
|
149
|
+
masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
|
|
150
|
+
# Get the first (smallest) index where max occurs
|
|
151
|
+
argmax_idx = tl.min(masked_offsets)
|
|
152
|
+
|
|
130
153
|
if label_smoothing > 0:
|
|
131
154
|
# scale X beforehand to avoid overflow
|
|
132
155
|
if HAS_WEIGHT:
|
|
@@ -256,6 +279,10 @@ def liger_cross_entropy_kernel(
|
|
|
256
279
|
tl.store(loss_ptr, loss)
|
|
257
280
|
if RETURN_Z_LOSS:
|
|
258
281
|
tl.store(z_loss_ptr, z_loss)
|
|
282
|
+
if RETURN_TOKEN_ACCURACY:
|
|
283
|
+
# Store 1.0 if prediction is correct, 0.0 otherwise
|
|
284
|
+
is_correct = 1.0 if argmax_idx == y else 0.0
|
|
285
|
+
tl.store(token_accuracy_ptr, is_correct)
|
|
259
286
|
|
|
260
287
|
|
|
261
288
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
|
@@ -274,8 +301,12 @@ def cross_entropy_forward(
|
|
|
274
301
|
reduction,
|
|
275
302
|
softcap,
|
|
276
303
|
return_z_loss,
|
|
304
|
+
return_token_accuracy=False,
|
|
277
305
|
):
|
|
278
306
|
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
|
307
|
+
assert isinstance(return_token_accuracy, bool), (
|
|
308
|
+
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
|
|
309
|
+
)
|
|
279
310
|
|
|
280
311
|
BT, V = _input.shape
|
|
281
312
|
n_rows = BT
|
|
@@ -285,6 +316,9 @@ def cross_entropy_forward(
|
|
|
285
316
|
# unreduced loss
|
|
286
317
|
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
|
|
287
318
|
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
|
|
319
|
+
token_accuracy_1d = (
|
|
320
|
+
torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
|
|
321
|
+
)
|
|
288
322
|
|
|
289
323
|
target_mask = target != ignore_index
|
|
290
324
|
n_non_ignore = target_mask.sum().item()
|
|
@@ -321,6 +355,10 @@ def cross_entropy_forward(
|
|
|
321
355
|
loss_ptr=loss_1d,
|
|
322
356
|
z_loss_ptr=z_loss_1d,
|
|
323
357
|
loss_stride=loss_1d.stride(-1), # always 1
|
|
358
|
+
token_accuracy_ptr=token_accuracy_1d,
|
|
359
|
+
token_accuracy_stride=token_accuracy_1d.stride(-1)
|
|
360
|
+
if return_token_accuracy
|
|
361
|
+
else 0, # always 1 if accuracy is enabled
|
|
324
362
|
n_cols=V,
|
|
325
363
|
n_non_ignore=n_non_ignore,
|
|
326
364
|
sum_non_ignore_weight=sum_non_ignore_weight,
|
|
@@ -331,6 +369,7 @@ def cross_entropy_forward(
|
|
|
331
369
|
reduction=reduction,
|
|
332
370
|
softcap=softcap,
|
|
333
371
|
RETURN_Z_LOSS=return_z_loss,
|
|
372
|
+
RETURN_TOKEN_ACCURACY=return_token_accuracy,
|
|
334
373
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
335
374
|
HAS_WEIGHT=True if weight is not None else False,
|
|
336
375
|
HAS_SOFTCAPPING=True if softcap is not None else False,
|
|
@@ -343,11 +382,14 @@ def cross_entropy_forward(
|
|
|
343
382
|
if reduction == "none":
|
|
344
383
|
loss = loss_1d
|
|
345
384
|
z_loss = z_loss_1d if return_z_loss else None
|
|
385
|
+
token_accuracy = token_accuracy_1d if return_token_accuracy else None
|
|
346
386
|
else:
|
|
347
387
|
loss = torch.sum(loss_1d)
|
|
348
388
|
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
|
389
|
+
# For accuracy, we compute the mean across all non-ignored tokens
|
|
390
|
+
token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
|
|
349
391
|
|
|
350
|
-
return loss, z_loss, _input
|
|
392
|
+
return loss, z_loss, token_accuracy, _input
|
|
351
393
|
|
|
352
394
|
|
|
353
395
|
def cross_entropy_backward(_input, grad_output):
|
|
@@ -395,6 +437,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
|
395
437
|
reduction: str = "mean",
|
|
396
438
|
softcap: Optional[float] = None,
|
|
397
439
|
return_z_loss: bool = False,
|
|
440
|
+
return_token_accuracy: bool = False,
|
|
398
441
|
):
|
|
399
442
|
"""
|
|
400
443
|
The forward pass of the Liger Cross Entropy loss.
|
|
@@ -409,14 +452,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
|
409
452
|
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
|
410
453
|
reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
|
|
411
454
|
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
|
|
412
|
-
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
|
|
455
|
+
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
|
|
456
|
+
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
|
|
413
457
|
|
|
414
458
|
Returns:
|
|
415
|
-
tuple: A tuple with the
|
|
459
|
+
tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
|
|
416
460
|
"""
|
|
417
461
|
input_requires_grad = _input.requires_grad
|
|
418
462
|
|
|
419
|
-
loss, z_loss, _input = cross_entropy_forward(
|
|
463
|
+
loss, z_loss, token_accuracy, _input = cross_entropy_forward(
|
|
420
464
|
_input,
|
|
421
465
|
target,
|
|
422
466
|
weight,
|
|
@@ -426,6 +470,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
|
426
470
|
reduction,
|
|
427
471
|
softcap,
|
|
428
472
|
return_z_loss,
|
|
473
|
+
return_token_accuracy,
|
|
429
474
|
)
|
|
430
475
|
# TODO: investigation
|
|
431
476
|
# If we don't detach the _input tensor, the memory will double
|
|
@@ -433,23 +478,27 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
|
433
478
|
if input_requires_grad:
|
|
434
479
|
ctx.save_for_backward(_input.detach())
|
|
435
480
|
ctx.return_z_loss = return_z_loss
|
|
481
|
+
ctx.return_token_accuracy = return_token_accuracy
|
|
436
482
|
|
|
437
|
-
return loss, z_loss
|
|
483
|
+
return loss, z_loss, token_accuracy
|
|
438
484
|
|
|
439
485
|
@staticmethod
|
|
440
|
-
def backward(ctx, grad_output,
|
|
486
|
+
def backward(ctx, grad_output, grad_output2, grad_output3):
|
|
441
487
|
"""
|
|
442
488
|
The backward pass of the Liger Cross Entropy loss.
|
|
443
489
|
|
|
444
490
|
Parameters:
|
|
445
491
|
ctx : The context object with saved tensors.
|
|
446
492
|
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
|
|
447
|
-
grad_output2 (
|
|
493
|
+
grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
|
|
494
|
+
grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
|
|
448
495
|
Returns:
|
|
449
496
|
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
|
|
450
497
|
"""
|
|
451
498
|
if ctx.return_z_loss:
|
|
452
|
-
del
|
|
499
|
+
del grad_output2 # z_loss is only for logging
|
|
500
|
+
if ctx.return_token_accuracy:
|
|
501
|
+
del grad_output3 # token_accuracy is only for metrics
|
|
453
502
|
|
|
454
503
|
(_input,) = ctx.saved_tensors
|
|
455
504
|
_input = cross_entropy_backward(_input, grad_output)
|
|
@@ -463,4 +512,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
|
463
512
|
None,
|
|
464
513
|
None,
|
|
465
514
|
None,
|
|
515
|
+
None,
|
|
466
516
|
)
|
|
@@ -27,8 +27,12 @@ def fused_linear_cross_entropy_forward(
|
|
|
27
27
|
return_z_loss=False,
|
|
28
28
|
accum_dtype=None,
|
|
29
29
|
use_token_scaling=False,
|
|
30
|
+
return_token_accuracy=False,
|
|
30
31
|
):
|
|
31
32
|
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
|
33
|
+
assert isinstance(return_token_accuracy, bool), (
|
|
34
|
+
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
|
|
35
|
+
)
|
|
32
36
|
device = _input.device
|
|
33
37
|
|
|
34
38
|
input_requires_grad = _input.requires_grad
|
|
@@ -64,6 +68,7 @@ def fused_linear_cross_entropy_forward(
|
|
|
64
68
|
|
|
65
69
|
loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
|
|
66
70
|
z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
|
|
71
|
+
token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
|
|
67
72
|
|
|
68
73
|
# TODO: evaluate how CUDA synchronization caused by .item() affects the speed
|
|
69
74
|
target_mask = target != ignore_index
|
|
@@ -129,6 +134,7 @@ def fused_linear_cross_entropy_forward(
|
|
|
129
134
|
# unreduced loss
|
|
130
135
|
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
|
|
131
136
|
z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
|
|
137
|
+
token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
|
|
132
138
|
|
|
133
139
|
# ensure _input and target are contiguous
|
|
134
140
|
logits_chunk = logits_chunk.contiguous()
|
|
@@ -144,6 +150,10 @@ def fused_linear_cross_entropy_forward(
|
|
|
144
150
|
loss_ptr=loss_1d_slice,
|
|
145
151
|
z_loss_ptr=z_loss_1d_slice,
|
|
146
152
|
loss_stride=loss_1d_slice.stride(-1), # always 1
|
|
153
|
+
token_accuracy_ptr=token_accuracy_1d_slice,
|
|
154
|
+
token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
|
|
155
|
+
if return_token_accuracy
|
|
156
|
+
else 0, # always 1 if accuracy is enabled
|
|
147
157
|
n_cols=V,
|
|
148
158
|
n_non_ignore=total_n_non_ignore,
|
|
149
159
|
sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
|
|
@@ -154,6 +164,7 @@ def fused_linear_cross_entropy_forward(
|
|
|
154
164
|
reduction=reduction,
|
|
155
165
|
softcap=softcap,
|
|
156
166
|
RETURN_Z_LOSS=return_z_loss,
|
|
167
|
+
RETURN_TOKEN_ACCURACY=return_token_accuracy,
|
|
157
168
|
HAS_WEIGHT=True if ce_weight is not None else False,
|
|
158
169
|
HAS_SOFTCAPPING=True if softcap is not None else False,
|
|
159
170
|
HAS_GRADIENTS=input_requires_grad,
|
|
@@ -170,6 +181,8 @@ def fused_linear_cross_entropy_forward(
|
|
|
170
181
|
loss_1d[start_idx:end_idx] = loss_1d_slice
|
|
171
182
|
if return_z_loss:
|
|
172
183
|
z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
|
|
184
|
+
if return_token_accuracy:
|
|
185
|
+
token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
|
|
173
186
|
grad_logits_chunk = logits_chunk # chunk_size x V
|
|
174
187
|
|
|
175
188
|
# Apply token scaling to gradients if requested
|
|
@@ -201,15 +214,18 @@ def fused_linear_cross_entropy_forward(
|
|
|
201
214
|
# Return per-token losses
|
|
202
215
|
loss = loss_1d
|
|
203
216
|
z_loss = z_loss_1d if return_z_loss else None
|
|
217
|
+
token_accuracy = token_accuracy_1d if return_token_accuracy else None
|
|
204
218
|
else:
|
|
205
219
|
loss = torch.sum(loss_1d)
|
|
206
220
|
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
|
221
|
+
# For accuracy, we compute the mean across all non-ignored tokens
|
|
222
|
+
token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
|
|
207
223
|
|
|
208
224
|
# Cast back to original dtype
|
|
209
225
|
grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
|
|
210
226
|
grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
|
|
211
227
|
|
|
212
|
-
return loss, z_loss, grad_input, grad_weight, grad_bias
|
|
228
|
+
return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias
|
|
213
229
|
|
|
214
230
|
|
|
215
231
|
def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
|
|
@@ -277,6 +293,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
277
293
|
return_z_loss: bool = False,
|
|
278
294
|
accum_dtype=None,
|
|
279
295
|
use_token_scaling: bool = False,
|
|
296
|
+
return_token_accuracy: bool = False,
|
|
280
297
|
):
|
|
281
298
|
"""
|
|
282
299
|
Fusing the last linear layer with cross-entropy loss
|
|
@@ -300,9 +317,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
300
317
|
use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
|
|
301
318
|
When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
|
|
302
319
|
Default: False.
|
|
320
|
+
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
|
|
303
321
|
"""
|
|
304
322
|
|
|
305
|
-
loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
|
|
323
|
+
loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
|
|
306
324
|
_input=_input,
|
|
307
325
|
weight=weight,
|
|
308
326
|
target=target,
|
|
@@ -316,6 +334,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
316
334
|
return_z_loss=return_z_loss,
|
|
317
335
|
accum_dtype=accum_dtype,
|
|
318
336
|
use_token_scaling=use_token_scaling,
|
|
337
|
+
return_token_accuracy=return_token_accuracy,
|
|
319
338
|
)
|
|
320
339
|
# downcast to dtype and store for backward
|
|
321
340
|
ctx.save_for_backward(
|
|
@@ -324,13 +343,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
324
343
|
grad_bias.detach() if bias is not None else None,
|
|
325
344
|
)
|
|
326
345
|
ctx.return_z_loss = return_z_loss
|
|
327
|
-
|
|
346
|
+
ctx.return_token_accuracy = return_token_accuracy
|
|
347
|
+
return loss, z_loss, token_accuracy
|
|
328
348
|
|
|
329
349
|
@staticmethod
|
|
330
350
|
@amp_custom_bwd
|
|
331
|
-
def backward(ctx, grad_output, grad_output2):
|
|
351
|
+
def backward(ctx, grad_output, grad_output2, grad_output3):
|
|
332
352
|
if ctx.return_z_loss:
|
|
333
353
|
del grad_output2 # z_loss is only for logging
|
|
354
|
+
if ctx.return_token_accuracy:
|
|
355
|
+
del grad_output3 # token_accuracy is only for metrics
|
|
334
356
|
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
|
|
335
357
|
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
|
|
336
358
|
grad_output, grad_input, grad_weight, grad_bias
|
|
@@ -349,4 +371,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
349
371
|
None,
|
|
350
372
|
None,
|
|
351
373
|
None, # use_token_scaling
|
|
374
|
+
None, # return_token_accuracy
|
|
352
375
|
)
|
|
@@ -3,6 +3,7 @@ from typing import Optional
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
5
|
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
6
|
+
from liger_kernel.transformers.functional import CrossEntropyOutput
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class LigerCrossEntropyLoss(torch.nn.Module):
|
|
@@ -15,6 +16,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
|
|
|
15
16
|
reduction: str = "mean",
|
|
16
17
|
softcap: Optional[float] = None,
|
|
17
18
|
return_z_loss: bool = False,
|
|
19
|
+
return_token_accuracy: bool = False,
|
|
18
20
|
):
|
|
19
21
|
super().__init__()
|
|
20
22
|
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
|
|
@@ -33,9 +35,10 @@ class LigerCrossEntropyLoss(torch.nn.Module):
|
|
|
33
35
|
self.reduction = reduction
|
|
34
36
|
self.softcap = softcap
|
|
35
37
|
self.return_z_loss = return_z_loss
|
|
38
|
+
self.return_token_accuracy = return_token_accuracy
|
|
36
39
|
|
|
37
40
|
def forward(self, _input: torch.Tensor, target: torch.Tensor):
|
|
38
|
-
loss, z_loss = LigerCrossEntropyFunction.apply(
|
|
41
|
+
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
|
|
39
42
|
_input,
|
|
40
43
|
target,
|
|
41
44
|
self.weight,
|
|
@@ -45,7 +48,9 @@ class LigerCrossEntropyLoss(torch.nn.Module):
|
|
|
45
48
|
self.reduction,
|
|
46
49
|
self.softcap,
|
|
47
50
|
self.return_z_loss,
|
|
51
|
+
self.return_token_accuracy,
|
|
48
52
|
)
|
|
49
|
-
if not self.return_z_loss:
|
|
53
|
+
if not self.return_z_loss and not self.return_token_accuracy:
|
|
50
54
|
return loss
|
|
51
|
-
|
|
55
|
+
|
|
56
|
+
return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
|
|
@@ -1,5 +1,8 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
1
2
|
from typing import Optional
|
|
2
3
|
|
|
4
|
+
import torch
|
|
5
|
+
|
|
3
6
|
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
4
7
|
from liger_kernel.ops.dyt import LigerDyTFunction
|
|
5
8
|
from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
|
|
@@ -22,6 +25,13 @@ from liger_kernel.ops.swiglu import LigerSiLUMulFunction
|
|
|
22
25
|
from liger_kernel.ops.tvd import LigerTVDLossFunction
|
|
23
26
|
|
|
24
27
|
|
|
28
|
+
@dataclass
|
|
29
|
+
class CrossEntropyOutput:
|
|
30
|
+
loss: torch.Tensor
|
|
31
|
+
z_loss: Optional[torch.Tensor] = None
|
|
32
|
+
token_accuracy: Optional[torch.Tensor] = None
|
|
33
|
+
|
|
34
|
+
|
|
25
35
|
# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
|
|
26
36
|
# `weight` and `size_average` are placeholders and not implemented yet
|
|
27
37
|
def liger_cross_entropy(
|
|
@@ -36,8 +46,9 @@ def liger_cross_entropy(
|
|
|
36
46
|
lse_square_scale: float = 0.0,
|
|
37
47
|
softcap: Optional[float] = None,
|
|
38
48
|
return_z_loss: bool = False,
|
|
49
|
+
return_token_accuracy: bool = False,
|
|
39
50
|
):
|
|
40
|
-
loss, z_loss = LigerCrossEntropyFunction.apply(
|
|
51
|
+
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
|
|
41
52
|
input,
|
|
42
53
|
target,
|
|
43
54
|
weight,
|
|
@@ -47,10 +58,13 @@ def liger_cross_entropy(
|
|
|
47
58
|
reduction,
|
|
48
59
|
softcap,
|
|
49
60
|
return_z_loss,
|
|
61
|
+
return_token_accuracy,
|
|
50
62
|
)
|
|
51
|
-
|
|
63
|
+
|
|
64
|
+
if not return_z_loss and not return_token_accuracy:
|
|
52
65
|
return loss
|
|
53
|
-
|
|
66
|
+
|
|
67
|
+
return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
|
|
54
68
|
|
|
55
69
|
|
|
56
70
|
def liger_fused_linear_cross_entropy(
|
|
@@ -67,8 +81,9 @@ def liger_fused_linear_cross_entropy(
|
|
|
67
81
|
return_z_loss: bool = False,
|
|
68
82
|
accum_dtype=None,
|
|
69
83
|
use_token_scaling: bool = False,
|
|
84
|
+
return_token_accuracy: bool = False,
|
|
70
85
|
):
|
|
71
|
-
loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
|
|
86
|
+
loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
|
|
72
87
|
input,
|
|
73
88
|
weight,
|
|
74
89
|
target,
|
|
@@ -82,10 +97,13 @@ def liger_fused_linear_cross_entropy(
|
|
|
82
97
|
return_z_loss,
|
|
83
98
|
accum_dtype,
|
|
84
99
|
use_token_scaling,
|
|
100
|
+
return_token_accuracy,
|
|
85
101
|
)
|
|
86
|
-
|
|
102
|
+
|
|
103
|
+
if not return_z_loss and not return_token_accuracy:
|
|
87
104
|
return loss
|
|
88
|
-
|
|
105
|
+
|
|
106
|
+
return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
|
|
89
107
|
|
|
90
108
|
|
|
91
109
|
def liger_fused_linear_jsd(
|
|
@@ -3,6 +3,7 @@ from typing import Optional
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
5
|
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
|
|
6
|
+
from liger_kernel.transformers.functional import CrossEntropyOutput
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
@@ -17,6 +18,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
|
17
18
|
return_z_loss: bool = False,
|
|
18
19
|
accum_dtype: Optional[torch.dtype] = None,
|
|
19
20
|
use_token_scaling: bool = False,
|
|
21
|
+
return_token_accuracy: bool = False,
|
|
20
22
|
):
|
|
21
23
|
super().__init__()
|
|
22
24
|
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
|
|
@@ -37,9 +39,10 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
|
37
39
|
self.return_z_loss = return_z_loss
|
|
38
40
|
self.accum_dtype = accum_dtype
|
|
39
41
|
self.use_token_scaling = use_token_scaling
|
|
42
|
+
self.return_token_accuracy = return_token_accuracy
|
|
40
43
|
|
|
41
44
|
def forward(self, lin_weight, _input, target, bias=None):
|
|
42
|
-
loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
|
|
45
|
+
loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
|
|
43
46
|
_input,
|
|
44
47
|
lin_weight,
|
|
45
48
|
target,
|
|
@@ -53,7 +56,9 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
|
53
56
|
self.return_z_loss,
|
|
54
57
|
self.accum_dtype,
|
|
55
58
|
self.use_token_scaling,
|
|
59
|
+
self.return_token_accuracy,
|
|
56
60
|
)
|
|
57
|
-
if not self.return_z_loss:
|
|
61
|
+
if not self.return_z_loss and not self.return_token_accuracy:
|
|
58
62
|
return loss
|
|
59
|
-
|
|
63
|
+
|
|
64
|
+
return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
|
|
@@ -4,12 +4,12 @@ from typing import Union
|
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
7
|
-
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
8
|
-
|
|
9
7
|
if TYPE_CHECKING:
|
|
10
8
|
from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache
|
|
11
9
|
|
|
12
10
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
11
|
+
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
|
|
12
|
+
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def lce_forward(
|
|
@@ -26,8 +26,9 @@ def lce_forward(
|
|
|
26
26
|
cache_position: Optional[torch.LongTensor] = None,
|
|
27
27
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
28
28
|
skip_logits: Optional[bool] = None,
|
|
29
|
+
return_dict: Optional[bool] = None,
|
|
29
30
|
**kwargs,
|
|
30
|
-
) -> Union[tuple, CausalLMOutputWithPast]:
|
|
31
|
+
) -> Union[tuple, LigerCausalLMOutputWithPast]:
|
|
31
32
|
r"""
|
|
32
33
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
33
34
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
@@ -54,6 +55,7 @@ def lce_forward(
|
|
|
54
55
|
output_hidden_states = (
|
|
55
56
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
56
57
|
)
|
|
58
|
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
57
59
|
|
|
58
60
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
59
61
|
outputs = self.model(
|
|
@@ -77,6 +79,8 @@ def lce_forward(
|
|
|
77
79
|
shift_labels = kwargs.pop("shift_labels", None)
|
|
78
80
|
logits = None
|
|
79
81
|
loss = None
|
|
82
|
+
token_accuracy = None
|
|
83
|
+
|
|
80
84
|
# if in training mode, don't materialize logits
|
|
81
85
|
if skip_logits and labels is None:
|
|
82
86
|
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
@@ -85,8 +89,9 @@ def lce_forward(
|
|
|
85
89
|
# By default, if in training mode, don't materialize logits
|
|
86
90
|
skip_logits = self.training and labels is not None
|
|
87
91
|
|
|
92
|
+
# Compute loss
|
|
88
93
|
if skip_logits:
|
|
89
|
-
|
|
94
|
+
result = LigerForCausalLMLoss(
|
|
90
95
|
hidden_states=kept_hidden_states,
|
|
91
96
|
lm_head_weight=self.lm_head.weight,
|
|
92
97
|
labels=labels,
|
|
@@ -94,15 +99,24 @@ def lce_forward(
|
|
|
94
99
|
hidden_size=self.config.hidden_size,
|
|
95
100
|
**kwargs,
|
|
96
101
|
)
|
|
102
|
+
loss, _, token_accuracy = unpack_cross_entropy_result(result)
|
|
97
103
|
else:
|
|
98
104
|
logits = self.lm_head(kept_hidden_states)
|
|
99
105
|
if labels is not None or shift_labels is not None:
|
|
100
106
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
|
101
107
|
|
|
102
|
-
|
|
108
|
+
if not return_dict:
|
|
109
|
+
output = (logits,) + outputs[1:]
|
|
110
|
+
output = ((loss,) + output) if loss is not None else output
|
|
111
|
+
output = output + (token_accuracy,) if token_accuracy is not None else output
|
|
112
|
+
return output
|
|
113
|
+
|
|
114
|
+
# Return custom output class with token_accuracy field
|
|
115
|
+
return LigerCausalLMOutputWithPast(
|
|
103
116
|
loss=loss,
|
|
104
117
|
logits=logits,
|
|
105
118
|
past_key_values=outputs.past_key_values,
|
|
106
119
|
hidden_states=outputs.hidden_states,
|
|
107
120
|
attentions=outputs.attentions,
|
|
121
|
+
token_accuracy=token_accuracy,
|
|
108
122
|
)
|
|
@@ -12,6 +12,8 @@ from transformers.utils.deprecation import deprecate_kwarg
|
|
|
12
12
|
|
|
13
13
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
14
14
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
15
|
+
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
|
|
16
|
+
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
|
|
15
17
|
|
|
16
18
|
|
|
17
19
|
def lce_forward_deprecated(
|
|
@@ -147,7 +149,7 @@ def lce_forward(
|
|
|
147
149
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
148
150
|
skip_logits: Optional[bool] = None,
|
|
149
151
|
**kwargs,
|
|
150
|
-
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
152
|
+
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
|
|
151
153
|
r"""
|
|
152
154
|
Args:
|
|
153
155
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
@@ -209,6 +211,7 @@ def lce_forward(
|
|
|
209
211
|
shift_labels = kwargs.pop("shift_labels", None)
|
|
210
212
|
logits = None
|
|
211
213
|
loss = None
|
|
214
|
+
token_accuracy = None
|
|
212
215
|
|
|
213
216
|
if skip_logits and labels is None and shift_labels is None:
|
|
214
217
|
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
@@ -217,8 +220,9 @@ def lce_forward(
|
|
|
217
220
|
# By default, if in training mode, don't materialize logits
|
|
218
221
|
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
219
222
|
|
|
223
|
+
# Compute loss
|
|
220
224
|
if skip_logits:
|
|
221
|
-
|
|
225
|
+
result = LigerForCausalLMLoss(
|
|
222
226
|
hidden_states=kept_hidden_states,
|
|
223
227
|
lm_head_weight=self.lm_head.weight,
|
|
224
228
|
labels=labels,
|
|
@@ -226,6 +230,7 @@ def lce_forward(
|
|
|
226
230
|
hidden_size=self.config.hidden_size,
|
|
227
231
|
**kwargs,
|
|
228
232
|
)
|
|
233
|
+
loss, _, token_accuracy = unpack_cross_entropy_result(result)
|
|
229
234
|
else:
|
|
230
235
|
logits = self.lm_head(kept_hidden_states)
|
|
231
236
|
if labels is not None or shift_labels is not None:
|
|
@@ -238,13 +243,19 @@ def lce_forward(
|
|
|
238
243
|
)
|
|
239
244
|
|
|
240
245
|
if not return_dict:
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
246
|
+
output_tuple = (logits,) + outputs[1:]
|
|
247
|
+
if loss is not None:
|
|
248
|
+
output_tuple = (loss,) + output_tuple
|
|
249
|
+
if token_accuracy is not None:
|
|
250
|
+
output_tuple = output_tuple + (token_accuracy,)
|
|
251
|
+
return output_tuple
|
|
252
|
+
|
|
253
|
+
# Return custom output class with token_accuracy field
|
|
254
|
+
return LigerCausalLMOutputWithPast(
|
|
245
255
|
loss=loss,
|
|
246
256
|
logits=logits,
|
|
247
257
|
past_key_values=outputs.past_key_values,
|
|
248
258
|
hidden_states=outputs.hidden_states,
|
|
249
259
|
attentions=outputs.attentions,
|
|
260
|
+
token_accuracy=token_accuracy,
|
|
250
261
|
)
|