@synsci/cli-darwin-x64 1.1.49
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/skills/accelerate/SKILL.md +332 -0
- package/bin/skills/accelerate/references/custom-plugins.md +453 -0
- package/bin/skills/accelerate/references/megatron-integration.md +489 -0
- package/bin/skills/accelerate/references/performance.md +525 -0
- package/bin/skills/audiocraft/SKILL.md +564 -0
- package/bin/skills/audiocraft/references/advanced-usage.md +666 -0
- package/bin/skills/audiocraft/references/troubleshooting.md +504 -0
- package/bin/skills/autogpt/SKILL.md +403 -0
- package/bin/skills/autogpt/references/advanced-usage.md +535 -0
- package/bin/skills/autogpt/references/troubleshooting.md +420 -0
- package/bin/skills/awq/SKILL.md +310 -0
- package/bin/skills/awq/references/advanced-usage.md +324 -0
- package/bin/skills/awq/references/troubleshooting.md +344 -0
- package/bin/skills/axolotl/SKILL.md +158 -0
- package/bin/skills/axolotl/references/api.md +5548 -0
- package/bin/skills/axolotl/references/dataset-formats.md +1029 -0
- package/bin/skills/axolotl/references/index.md +15 -0
- package/bin/skills/axolotl/references/other.md +3563 -0
- package/bin/skills/bigcode-evaluation-harness/SKILL.md +405 -0
- package/bin/skills/bigcode-evaluation-harness/references/benchmarks.md +393 -0
- package/bin/skills/bigcode-evaluation-harness/references/custom-tasks.md +424 -0
- package/bin/skills/bigcode-evaluation-harness/references/issues.md +394 -0
- package/bin/skills/bitsandbytes/SKILL.md +411 -0
- package/bin/skills/bitsandbytes/references/memory-optimization.md +521 -0
- package/bin/skills/bitsandbytes/references/qlora-training.md +521 -0
- package/bin/skills/bitsandbytes/references/quantization-formats.md +447 -0
- package/bin/skills/blip-2/SKILL.md +564 -0
- package/bin/skills/blip-2/references/advanced-usage.md +680 -0
- package/bin/skills/blip-2/references/troubleshooting.md +526 -0
- package/bin/skills/chroma/SKILL.md +406 -0
- package/bin/skills/chroma/references/integration.md +38 -0
- package/bin/skills/clip/SKILL.md +253 -0
- package/bin/skills/clip/references/applications.md +207 -0
- package/bin/skills/constitutional-ai/SKILL.md +290 -0
- package/bin/skills/crewai/SKILL.md +498 -0
- package/bin/skills/crewai/references/flows.md +438 -0
- package/bin/skills/crewai/references/tools.md +429 -0
- package/bin/skills/crewai/references/troubleshooting.md +480 -0
- package/bin/skills/deepspeed/SKILL.md +141 -0
- package/bin/skills/deepspeed/references/08.md +17 -0
- package/bin/skills/deepspeed/references/09.md +173 -0
- package/bin/skills/deepspeed/references/2020.md +378 -0
- package/bin/skills/deepspeed/references/2023.md +279 -0
- package/bin/skills/deepspeed/references/assets.md +179 -0
- package/bin/skills/deepspeed/references/index.md +35 -0
- package/bin/skills/deepspeed/references/mii.md +118 -0
- package/bin/skills/deepspeed/references/other.md +1191 -0
- package/bin/skills/deepspeed/references/tutorials.md +6554 -0
- package/bin/skills/dspy/SKILL.md +590 -0
- package/bin/skills/dspy/references/examples.md +663 -0
- package/bin/skills/dspy/references/modules.md +475 -0
- package/bin/skills/dspy/references/optimizers.md +566 -0
- package/bin/skills/faiss/SKILL.md +221 -0
- package/bin/skills/faiss/references/index_types.md +280 -0
- package/bin/skills/flash-attention/SKILL.md +367 -0
- package/bin/skills/flash-attention/references/benchmarks.md +215 -0
- package/bin/skills/flash-attention/references/transformers-integration.md +293 -0
- package/bin/skills/gguf/SKILL.md +427 -0
- package/bin/skills/gguf/references/advanced-usage.md +504 -0
- package/bin/skills/gguf/references/troubleshooting.md +442 -0
- package/bin/skills/gptq/SKILL.md +450 -0
- package/bin/skills/gptq/references/calibration.md +337 -0
- package/bin/skills/gptq/references/integration.md +129 -0
- package/bin/skills/gptq/references/troubleshooting.md +95 -0
- package/bin/skills/grpo-rl-training/README.md +97 -0
- package/bin/skills/grpo-rl-training/SKILL.md +572 -0
- package/bin/skills/grpo-rl-training/examples/reward_functions_library.py +393 -0
- package/bin/skills/grpo-rl-training/templates/basic_grpo_training.py +228 -0
- package/bin/skills/guidance/SKILL.md +572 -0
- package/bin/skills/guidance/references/backends.md +554 -0
- package/bin/skills/guidance/references/constraints.md +674 -0
- package/bin/skills/guidance/references/examples.md +767 -0
- package/bin/skills/hqq/SKILL.md +445 -0
- package/bin/skills/hqq/references/advanced-usage.md +528 -0
- package/bin/skills/hqq/references/troubleshooting.md +503 -0
- package/bin/skills/hugging-face-cli/SKILL.md +191 -0
- package/bin/skills/hugging-face-cli/references/commands.md +954 -0
- package/bin/skills/hugging-face-cli/references/examples.md +374 -0
- package/bin/skills/hugging-face-datasets/SKILL.md +547 -0
- package/bin/skills/hugging-face-datasets/examples/diverse_training_examples.json +239 -0
- package/bin/skills/hugging-face-datasets/examples/system_prompt_template.txt +196 -0
- package/bin/skills/hugging-face-datasets/examples/training_examples.json +176 -0
- package/bin/skills/hugging-face-datasets/scripts/dataset_manager.py +522 -0
- package/bin/skills/hugging-face-datasets/scripts/sql_manager.py +844 -0
- package/bin/skills/hugging-face-datasets/templates/chat.json +55 -0
- package/bin/skills/hugging-face-datasets/templates/classification.json +62 -0
- package/bin/skills/hugging-face-datasets/templates/completion.json +51 -0
- package/bin/skills/hugging-face-datasets/templates/custom.json +75 -0
- package/bin/skills/hugging-face-datasets/templates/qa.json +54 -0
- package/bin/skills/hugging-face-datasets/templates/tabular.json +81 -0
- package/bin/skills/hugging-face-evaluation/SKILL.md +656 -0
- package/bin/skills/hugging-face-evaluation/examples/USAGE_EXAMPLES.md +382 -0
- package/bin/skills/hugging-face-evaluation/examples/artificial_analysis_to_hub.py +141 -0
- package/bin/skills/hugging-face-evaluation/examples/example_readme_tables.md +135 -0
- package/bin/skills/hugging-face-evaluation/examples/metric_mapping.json +50 -0
- package/bin/skills/hugging-face-evaluation/requirements.txt +20 -0
- package/bin/skills/hugging-face-evaluation/scripts/evaluation_manager.py +1374 -0
- package/bin/skills/hugging-face-evaluation/scripts/inspect_eval_uv.py +104 -0
- package/bin/skills/hugging-face-evaluation/scripts/inspect_vllm_uv.py +317 -0
- package/bin/skills/hugging-face-evaluation/scripts/lighteval_vllm_uv.py +303 -0
- package/bin/skills/hugging-face-evaluation/scripts/run_eval_job.py +98 -0
- package/bin/skills/hugging-face-evaluation/scripts/run_vllm_eval_job.py +331 -0
- package/bin/skills/hugging-face-evaluation/scripts/test_extraction.py +206 -0
- package/bin/skills/hugging-face-jobs/SKILL.md +1041 -0
- package/bin/skills/hugging-face-jobs/index.html +216 -0
- package/bin/skills/hugging-face-jobs/references/hardware_guide.md +336 -0
- package/bin/skills/hugging-face-jobs/references/hub_saving.md +352 -0
- package/bin/skills/hugging-face-jobs/references/token_usage.md +546 -0
- package/bin/skills/hugging-face-jobs/references/troubleshooting.md +475 -0
- package/bin/skills/hugging-face-jobs/scripts/cot-self-instruct.py +718 -0
- package/bin/skills/hugging-face-jobs/scripts/finepdfs-stats.py +546 -0
- package/bin/skills/hugging-face-jobs/scripts/generate-responses.py +587 -0
- package/bin/skills/hugging-face-model-trainer/SKILL.md +711 -0
- package/bin/skills/hugging-face-model-trainer/references/gguf_conversion.md +296 -0
- package/bin/skills/hugging-face-model-trainer/references/hardware_guide.md +283 -0
- package/bin/skills/hugging-face-model-trainer/references/hub_saving.md +364 -0
- package/bin/skills/hugging-face-model-trainer/references/reliability_principles.md +371 -0
- package/bin/skills/hugging-face-model-trainer/references/trackio_guide.md +189 -0
- package/bin/skills/hugging-face-model-trainer/references/training_methods.md +150 -0
- package/bin/skills/hugging-face-model-trainer/references/training_patterns.md +203 -0
- package/bin/skills/hugging-face-model-trainer/references/troubleshooting.md +282 -0
- package/bin/skills/hugging-face-model-trainer/scripts/convert_to_gguf.py +424 -0
- package/bin/skills/hugging-face-model-trainer/scripts/dataset_inspector.py +417 -0
- package/bin/skills/hugging-face-model-trainer/scripts/estimate_cost.py +150 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_dpo_example.py +106 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_grpo_example.py +89 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_sft_example.py +122 -0
- package/bin/skills/hugging-face-paper-publisher/SKILL.md +627 -0
- package/bin/skills/hugging-face-paper-publisher/examples/example_usage.md +327 -0
- package/bin/skills/hugging-face-paper-publisher/references/quick_reference.md +216 -0
- package/bin/skills/hugging-face-paper-publisher/scripts/paper_manager.py +508 -0
- package/bin/skills/hugging-face-paper-publisher/templates/arxiv.md +299 -0
- package/bin/skills/hugging-face-paper-publisher/templates/ml-report.md +358 -0
- package/bin/skills/hugging-face-paper-publisher/templates/modern.md +319 -0
- package/bin/skills/hugging-face-paper-publisher/templates/standard.md +201 -0
- package/bin/skills/hugging-face-tool-builder/SKILL.md +115 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.py +57 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.sh +40 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.tsx +57 -0
- package/bin/skills/hugging-face-tool-builder/references/find_models_by_paper.sh +230 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_enrich_models.sh +96 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_model_card_frontmatter.sh +188 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_model_papers_auth.sh +171 -0
- package/bin/skills/hugging-face-trackio/SKILL.md +65 -0
- package/bin/skills/hugging-face-trackio/references/logging_metrics.md +206 -0
- package/bin/skills/hugging-face-trackio/references/retrieving_metrics.md +223 -0
- package/bin/skills/huggingface-tokenizers/SKILL.md +516 -0
- package/bin/skills/huggingface-tokenizers/references/algorithms.md +653 -0
- package/bin/skills/huggingface-tokenizers/references/integration.md +637 -0
- package/bin/skills/huggingface-tokenizers/references/pipeline.md +723 -0
- package/bin/skills/huggingface-tokenizers/references/training.md +565 -0
- package/bin/skills/instructor/SKILL.md +740 -0
- package/bin/skills/instructor/references/examples.md +107 -0
- package/bin/skills/instructor/references/providers.md +70 -0
- package/bin/skills/instructor/references/validation.md +606 -0
- package/bin/skills/knowledge-distillation/SKILL.md +458 -0
- package/bin/skills/knowledge-distillation/references/minillm.md +334 -0
- package/bin/skills/lambda-labs/SKILL.md +545 -0
- package/bin/skills/lambda-labs/references/advanced-usage.md +611 -0
- package/bin/skills/lambda-labs/references/troubleshooting.md +530 -0
- package/bin/skills/langchain/SKILL.md +480 -0
- package/bin/skills/langchain/references/agents.md +499 -0
- package/bin/skills/langchain/references/integration.md +562 -0
- package/bin/skills/langchain/references/rag.md +600 -0
- package/bin/skills/langsmith/SKILL.md +422 -0
- package/bin/skills/langsmith/references/advanced-usage.md +548 -0
- package/bin/skills/langsmith/references/troubleshooting.md +537 -0
- package/bin/skills/litgpt/SKILL.md +469 -0
- package/bin/skills/litgpt/references/custom-models.md +568 -0
- package/bin/skills/litgpt/references/distributed-training.md +451 -0
- package/bin/skills/litgpt/references/supported-models.md +336 -0
- package/bin/skills/litgpt/references/training-recipes.md +619 -0
- package/bin/skills/llama-cpp/SKILL.md +258 -0
- package/bin/skills/llama-cpp/references/optimization.md +89 -0
- package/bin/skills/llama-cpp/references/quantization.md +213 -0
- package/bin/skills/llama-cpp/references/server.md +125 -0
- package/bin/skills/llama-factory/SKILL.md +80 -0
- package/bin/skills/llama-factory/references/_images.md +23 -0
- package/bin/skills/llama-factory/references/advanced.md +1055 -0
- package/bin/skills/llama-factory/references/getting_started.md +349 -0
- package/bin/skills/llama-factory/references/index.md +19 -0
- package/bin/skills/llama-factory/references/other.md +31 -0
- package/bin/skills/llamaguard/SKILL.md +337 -0
- package/bin/skills/llamaindex/SKILL.md +569 -0
- package/bin/skills/llamaindex/references/agents.md +83 -0
- package/bin/skills/llamaindex/references/data_connectors.md +108 -0
- package/bin/skills/llamaindex/references/query_engines.md +406 -0
- package/bin/skills/llava/SKILL.md +304 -0
- package/bin/skills/llava/references/training.md +197 -0
- package/bin/skills/lm-evaluation-harness/SKILL.md +490 -0
- package/bin/skills/lm-evaluation-harness/references/api-evaluation.md +490 -0
- package/bin/skills/lm-evaluation-harness/references/benchmark-guide.md +488 -0
- package/bin/skills/lm-evaluation-harness/references/custom-tasks.md +602 -0
- package/bin/skills/lm-evaluation-harness/references/distributed-eval.md +519 -0
- package/bin/skills/long-context/SKILL.md +536 -0
- package/bin/skills/long-context/references/extension_methods.md +468 -0
- package/bin/skills/long-context/references/fine_tuning.md +611 -0
- package/bin/skills/long-context/references/rope.md +402 -0
- package/bin/skills/mamba/SKILL.md +260 -0
- package/bin/skills/mamba/references/architecture-details.md +206 -0
- package/bin/skills/mamba/references/benchmarks.md +255 -0
- package/bin/skills/mamba/references/training-guide.md +388 -0
- package/bin/skills/megatron-core/SKILL.md +366 -0
- package/bin/skills/megatron-core/references/benchmarks.md +249 -0
- package/bin/skills/megatron-core/references/parallelism-guide.md +404 -0
- package/bin/skills/megatron-core/references/production-examples.md +473 -0
- package/bin/skills/megatron-core/references/training-recipes.md +547 -0
- package/bin/skills/miles/SKILL.md +315 -0
- package/bin/skills/miles/references/api-reference.md +141 -0
- package/bin/skills/miles/references/troubleshooting.md +352 -0
- package/bin/skills/mlflow/SKILL.md +704 -0
- package/bin/skills/mlflow/references/deployment.md +744 -0
- package/bin/skills/mlflow/references/model-registry.md +770 -0
- package/bin/skills/mlflow/references/tracking.md +680 -0
- package/bin/skills/modal/SKILL.md +341 -0
- package/bin/skills/modal/references/advanced-usage.md +503 -0
- package/bin/skills/modal/references/troubleshooting.md +494 -0
- package/bin/skills/model-merging/SKILL.md +539 -0
- package/bin/skills/model-merging/references/evaluation.md +462 -0
- package/bin/skills/model-merging/references/examples.md +428 -0
- package/bin/skills/model-merging/references/methods.md +352 -0
- package/bin/skills/model-pruning/SKILL.md +495 -0
- package/bin/skills/model-pruning/references/wanda.md +347 -0
- package/bin/skills/moe-training/SKILL.md +526 -0
- package/bin/skills/moe-training/references/architectures.md +432 -0
- package/bin/skills/moe-training/references/inference.md +348 -0
- package/bin/skills/moe-training/references/training.md +425 -0
- package/bin/skills/nanogpt/SKILL.md +290 -0
- package/bin/skills/nanogpt/references/architecture.md +382 -0
- package/bin/skills/nanogpt/references/data.md +476 -0
- package/bin/skills/nanogpt/references/training.md +564 -0
- package/bin/skills/nemo-curator/SKILL.md +383 -0
- package/bin/skills/nemo-curator/references/deduplication.md +87 -0
- package/bin/skills/nemo-curator/references/filtering.md +102 -0
- package/bin/skills/nemo-evaluator/SKILL.md +494 -0
- package/bin/skills/nemo-evaluator/references/adapter-system.md +340 -0
- package/bin/skills/nemo-evaluator/references/configuration.md +447 -0
- package/bin/skills/nemo-evaluator/references/custom-benchmarks.md +315 -0
- package/bin/skills/nemo-evaluator/references/execution-backends.md +361 -0
- package/bin/skills/nemo-guardrails/SKILL.md +297 -0
- package/bin/skills/nnsight/SKILL.md +436 -0
- package/bin/skills/nnsight/references/README.md +78 -0
- package/bin/skills/nnsight/references/api.md +344 -0
- package/bin/skills/nnsight/references/tutorials.md +300 -0
- package/bin/skills/openrlhf/SKILL.md +249 -0
- package/bin/skills/openrlhf/references/algorithm-comparison.md +404 -0
- package/bin/skills/openrlhf/references/custom-rewards.md +530 -0
- package/bin/skills/openrlhf/references/hybrid-engine.md +287 -0
- package/bin/skills/openrlhf/references/multi-node-training.md +454 -0
- package/bin/skills/outlines/SKILL.md +652 -0
- package/bin/skills/outlines/references/backends.md +615 -0
- package/bin/skills/outlines/references/examples.md +773 -0
- package/bin/skills/outlines/references/json_generation.md +652 -0
- package/bin/skills/peft/SKILL.md +431 -0
- package/bin/skills/peft/references/advanced-usage.md +514 -0
- package/bin/skills/peft/references/troubleshooting.md +480 -0
- package/bin/skills/phoenix/SKILL.md +475 -0
- package/bin/skills/phoenix/references/advanced-usage.md +619 -0
- package/bin/skills/phoenix/references/troubleshooting.md +538 -0
- package/bin/skills/pinecone/SKILL.md +358 -0
- package/bin/skills/pinecone/references/deployment.md +181 -0
- package/bin/skills/pytorch-fsdp/SKILL.md +126 -0
- package/bin/skills/pytorch-fsdp/references/index.md +7 -0
- package/bin/skills/pytorch-fsdp/references/other.md +4249 -0
- package/bin/skills/pytorch-lightning/SKILL.md +346 -0
- package/bin/skills/pytorch-lightning/references/callbacks.md +436 -0
- package/bin/skills/pytorch-lightning/references/distributed.md +490 -0
- package/bin/skills/pytorch-lightning/references/hyperparameter-tuning.md +556 -0
- package/bin/skills/pyvene/SKILL.md +473 -0
- package/bin/skills/pyvene/references/README.md +73 -0
- package/bin/skills/pyvene/references/api.md +383 -0
- package/bin/skills/pyvene/references/tutorials.md +376 -0
- package/bin/skills/qdrant/SKILL.md +493 -0
- package/bin/skills/qdrant/references/advanced-usage.md +648 -0
- package/bin/skills/qdrant/references/troubleshooting.md +631 -0
- package/bin/skills/ray-data/SKILL.md +326 -0
- package/bin/skills/ray-data/references/integration.md +82 -0
- package/bin/skills/ray-data/references/transformations.md +83 -0
- package/bin/skills/ray-train/SKILL.md +406 -0
- package/bin/skills/ray-train/references/multi-node.md +628 -0
- package/bin/skills/rwkv/SKILL.md +260 -0
- package/bin/skills/rwkv/references/architecture-details.md +344 -0
- package/bin/skills/rwkv/references/rwkv7.md +386 -0
- package/bin/skills/rwkv/references/state-management.md +369 -0
- package/bin/skills/saelens/SKILL.md +386 -0
- package/bin/skills/saelens/references/README.md +70 -0
- package/bin/skills/saelens/references/api.md +333 -0
- package/bin/skills/saelens/references/tutorials.md +318 -0
- package/bin/skills/segment-anything/SKILL.md +500 -0
- package/bin/skills/segment-anything/references/advanced-usage.md +589 -0
- package/bin/skills/segment-anything/references/troubleshooting.md +484 -0
- package/bin/skills/sentence-transformers/SKILL.md +255 -0
- package/bin/skills/sentence-transformers/references/models.md +123 -0
- package/bin/skills/sentencepiece/SKILL.md +235 -0
- package/bin/skills/sentencepiece/references/algorithms.md +200 -0
- package/bin/skills/sentencepiece/references/training.md +304 -0
- package/bin/skills/sglang/SKILL.md +442 -0
- package/bin/skills/sglang/references/deployment.md +490 -0
- package/bin/skills/sglang/references/radix-attention.md +413 -0
- package/bin/skills/sglang/references/structured-generation.md +541 -0
- package/bin/skills/simpo/SKILL.md +219 -0
- package/bin/skills/simpo/references/datasets.md +478 -0
- package/bin/skills/simpo/references/hyperparameters.md +452 -0
- package/bin/skills/simpo/references/loss-functions.md +350 -0
- package/bin/skills/skypilot/SKILL.md +509 -0
- package/bin/skills/skypilot/references/advanced-usage.md +491 -0
- package/bin/skills/skypilot/references/troubleshooting.md +570 -0
- package/bin/skills/slime/SKILL.md +464 -0
- package/bin/skills/slime/references/api-reference.md +392 -0
- package/bin/skills/slime/references/troubleshooting.md +386 -0
- package/bin/skills/speculative-decoding/SKILL.md +467 -0
- package/bin/skills/speculative-decoding/references/lookahead.md +309 -0
- package/bin/skills/speculative-decoding/references/medusa.md +350 -0
- package/bin/skills/stable-diffusion/SKILL.md +519 -0
- package/bin/skills/stable-diffusion/references/advanced-usage.md +716 -0
- package/bin/skills/stable-diffusion/references/troubleshooting.md +555 -0
- package/bin/skills/tensorboard/SKILL.md +629 -0
- package/bin/skills/tensorboard/references/integrations.md +638 -0
- package/bin/skills/tensorboard/references/profiling.md +545 -0
- package/bin/skills/tensorboard/references/visualization.md +620 -0
- package/bin/skills/tensorrt-llm/SKILL.md +187 -0
- package/bin/skills/tensorrt-llm/references/multi-gpu.md +298 -0
- package/bin/skills/tensorrt-llm/references/optimization.md +242 -0
- package/bin/skills/tensorrt-llm/references/serving.md +470 -0
- package/bin/skills/tinker/SKILL.md +362 -0
- package/bin/skills/tinker/references/api-reference.md +168 -0
- package/bin/skills/tinker/references/getting-started.md +157 -0
- package/bin/skills/tinker/references/loss-functions.md +163 -0
- package/bin/skills/tinker/references/models-and-lora.md +139 -0
- package/bin/skills/tinker/references/recipes.md +280 -0
- package/bin/skills/tinker/references/reinforcement-learning.md +212 -0
- package/bin/skills/tinker/references/rendering.md +243 -0
- package/bin/skills/tinker/references/supervised-learning.md +232 -0
- package/bin/skills/tinker-training-cost/SKILL.md +187 -0
- package/bin/skills/tinker-training-cost/scripts/calculate_cost.py +123 -0
- package/bin/skills/torchforge/SKILL.md +433 -0
- package/bin/skills/torchforge/references/api-reference.md +327 -0
- package/bin/skills/torchforge/references/troubleshooting.md +409 -0
- package/bin/skills/torchtitan/SKILL.md +358 -0
- package/bin/skills/torchtitan/references/checkpoint.md +181 -0
- package/bin/skills/torchtitan/references/custom-models.md +258 -0
- package/bin/skills/torchtitan/references/float8.md +133 -0
- package/bin/skills/torchtitan/references/fsdp.md +126 -0
- package/bin/skills/transformer-lens/SKILL.md +346 -0
- package/bin/skills/transformer-lens/references/README.md +54 -0
- package/bin/skills/transformer-lens/references/api.md +362 -0
- package/bin/skills/transformer-lens/references/tutorials.md +339 -0
- package/bin/skills/trl-fine-tuning/SKILL.md +455 -0
- package/bin/skills/trl-fine-tuning/references/dpo-variants.md +227 -0
- package/bin/skills/trl-fine-tuning/references/online-rl.md +82 -0
- package/bin/skills/trl-fine-tuning/references/reward-modeling.md +122 -0
- package/bin/skills/trl-fine-tuning/references/sft-training.md +168 -0
- package/bin/skills/unsloth/SKILL.md +80 -0
- package/bin/skills/unsloth/references/index.md +7 -0
- package/bin/skills/unsloth/references/llms-full.md +16799 -0
- package/bin/skills/unsloth/references/llms-txt.md +12044 -0
- package/bin/skills/unsloth/references/llms.md +82 -0
- package/bin/skills/verl/SKILL.md +391 -0
- package/bin/skills/verl/references/api-reference.md +301 -0
- package/bin/skills/verl/references/troubleshooting.md +391 -0
- package/bin/skills/vllm/SKILL.md +364 -0
- package/bin/skills/vllm/references/optimization.md +226 -0
- package/bin/skills/vllm/references/quantization.md +284 -0
- package/bin/skills/vllm/references/server-deployment.md +255 -0
- package/bin/skills/vllm/references/troubleshooting.md +447 -0
- package/bin/skills/weights-and-biases/SKILL.md +590 -0
- package/bin/skills/weights-and-biases/references/artifacts.md +584 -0
- package/bin/skills/weights-and-biases/references/integrations.md +700 -0
- package/bin/skills/weights-and-biases/references/sweeps.md +847 -0
- package/bin/skills/whisper/SKILL.md +317 -0
- package/bin/skills/whisper/references/languages.md +189 -0
- package/bin/synsc +0 -0
- package/package.json +10 -0
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
# Lookahead Decoding: Jacobi Iteration
|
|
2
|
+
|
|
3
|
+
Based on ICML 2024 paper and LMSYS blog post
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
**Source**: https://lmsys.org/blog/2023-11-21-lookahead-decoding/
|
|
8
|
+
**Paper**: ICML 2024
|
|
9
|
+
**GitHub**: https://github.com/hao-ai-lab/LookaheadDecoding
|
|
10
|
+
|
|
11
|
+
Lookahead Decoding breaks sequential dependency in autoregressive decoding using Jacobi iteration, achieving 1.5-2.3× speedup without draft models or training.
|
|
12
|
+
|
|
13
|
+
## Core Concept
|
|
14
|
+
|
|
15
|
+
### Reformulation as Equation Solving
|
|
16
|
+
|
|
17
|
+
**Traditional autoregressive**:
|
|
18
|
+
```
|
|
19
|
+
y_t = f(x, y_1, y_2, ..., y_{t-1}) # Sequential
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
**Jacobi iteration**:
|
|
23
|
+
```
|
|
24
|
+
y_t^{(k+1)} = f(x, y_1^{(k)}, y_2^{(k)}, ..., y_{t-1}^{(k)}) # Parallel
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
**Key insight**: Although exact parallel decoding is impossible, we can generate multiple disjoint n-grams in parallel that may fit into the final sequence.
|
|
28
|
+
|
|
29
|
+
## Two-Branch Architecture
|
|
30
|
+
|
|
31
|
+
### Lookahead Branch
|
|
32
|
+
|
|
33
|
+
**Purpose**: Generate potential token sequences (n-grams) in parallel.
|
|
34
|
+
|
|
35
|
+
**Parameters**:
|
|
36
|
+
- `W` (window size): How many steps ahead to look
|
|
37
|
+
- `N` (n-gram size): How many past tokens to use for generation
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
# Example: W=5, N=3
|
|
41
|
+
# Generate n-grams at positions 1-5 using past 1-3 tokens
|
|
42
|
+
|
|
43
|
+
def lookahead_branch(model, tokens, W=5, N=3):
    """Propose candidate n-grams at future positions (Jacobi-style).

    For every (offset, context-length) pair, ask the model for an n-gram
    at a position `offset` steps past the current sequence end, conditioned
    on the most recent `length` tokens.

    Returns a dict mapping (offset, length) -> proposed n-gram.
    """
    return {
        (offset, length): model.generate_ngram(
            context=tokens[-length:],          # most recent `length` tokens
            position=len(tokens) + offset,     # absolute target position
            length=length,
        )
        for offset in range(1, W + 1)          # how far ahead we guess
        for length in range(1, N + 1)          # how much context we condition on
    }
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
**Output**: Pool of candidate n-grams that might match future sequence.
|
|
66
|
+
|
|
67
|
+
### Verification Branch
|
|
68
|
+
|
|
69
|
+
**Purpose**: Identify and confirm promising n-grams.
|
|
70
|
+
|
|
71
|
+
```python
|
|
72
|
+
def verification_branch(model, tokens, candidates):
    """Confirm candidate n-grams against the actual sequence.

    A candidate survives only when its first token equals the most recently
    generated token AND the model accepts the extended sequence.

    Returns the longest surviving n-gram, or None when nothing matches.
    """
    survivors = [
        ngram
        for ngram in candidates
        # Anchor check first (cheap), model verification second (expensive).
        if ngram[0] == tokens[-1] and model.verify_sequence(tokens + ngram)
    ]
    if not survivors:
        return None
    return max(survivors, key=len)
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
**Acceptance**: N-gram accepted if its first token matches the last input token and model confirms the sequence.
|
|
90
|
+
|
|
91
|
+
## Algorithm
|
|
92
|
+
|
|
93
|
+
### Complete Lookahead Decoding
|
|
94
|
+
|
|
95
|
+
```python
|
|
96
|
+
class LookaheadDecoding:
    """Reference implementation of lookahead decoding (ICML 2024).

    Each step runs a lookahead branch (propose n-grams) and a verification
    branch (accept the longest confirmed n-gram); when nothing verifies, it
    falls back to ordinary single-token decoding.
    """

    def __init__(self, model, W=15, N=5, G=5):
        """
        Args:
            W: Window size (lookahead distance)
            N: N-gram size (context length)
            G: Guess size (parallel candidates)
        """
        self.model = model
        self.W = W
        self.N = N
        self.G = G

    def generate(self, input_ids, max_new_tokens=256):
        """Generate up to ``max_new_tokens`` NEW tokens after the prompt.

        BUGFIX vs. the original: the loop previously tested
        ``len(tokens) < max_new_tokens``, which (a) counted the prompt toward
        the new-token budget and (b) measured the batch dimension for 2-D
        inputs. We now compare sequence length against prompt length plus
        the budget, and trim an accepted n-gram that would overshoot it.
        """
        tokens = input_ids.clone()
        # Budget counts new tokens only; shape[-1] is the sequence axis.
        target_len = input_ids.shape[-1] + max_new_tokens

        while tokens.shape[-1] < target_len:
            # 1. Lookahead: Generate candidates
            candidates = self._lookahead_step(tokens)

            # 2. Verification: Find matching n-grams
            accepted_ngram = self._verification_step(tokens, candidates)

            if accepted_ngram is not None:
                # Accept multiple tokens, but never exceed the budget.
                room = target_len - tokens.shape[-1]
                tokens = torch.cat([tokens, accepted_ngram[:room]])
            else:
                # Fallback: Generate single token
                next_token = self.model.generate_next(tokens)
                tokens = torch.cat([tokens, next_token])

        return tokens

    def _lookahead_step(self, tokens):
        """Generate candidate n-grams in parallel (W x N proposal grid)."""
        candidates = []

        for w in range(1, self.W + 1):
            for n in range(1, self.N + 1):
                # Sample G n-gram guesses for offset w with context size n.
                ngrams = self.model.sample_ngram(
                    tokens=tokens,
                    offset=w,
                    context_size=n,
                    num_samples=self.G,
                )
                candidates.extend(ngrams)

        return candidates

    def _verification_step(self, tokens, candidates):
        """Verify candidates and select the longest valid n-gram (or None)."""
        valid_ngrams = []

        anchor = self._get_next_token_prediction(tokens)
        for ngram in candidates:
            # Must match the model's actual next-token continuation.
            if ngram[0] == anchor:
                # Verify full sequence
                if self._verify_ngram(tokens, ngram):
                    valid_ngrams.append(ngram)

        # Return longest valid n-gram
        return max(valid_ngrams, key=len) if valid_ngrams else None

    def _get_next_token_prediction(self, tokens):
        # Acceptance anchor: the token the base model would emit next.
        # (Previously referenced but undefined in this reference snippet.)
        return self.model.generate_next(tokens)[0]

    def _verify_ngram(self, tokens, ngram):
        # Delegate to the model, mirroring verification_branch() above.
        # (Previously referenced but undefined in this reference snippet.)
        return self.model.verify_sequence(torch.cat([tokens, ngram]))
|
|
159
|
+
```
|
|
160
|
+
|
|
161
|
+
## Performance Analysis
|
|
162
|
+
|
|
163
|
+
### Speedup vs Parameters
|
|
164
|
+
|
|
165
|
+
**From paper (7B model on HumanEval)**:
|
|
166
|
+
|
|
167
|
+
| Window (W) | N-gram (N) | Speedup | Throughput |
|
|
168
|
+
|------------|------------|---------|------------|
|
|
169
|
+
| 5 | 3 | 1.5× | 45 tokens/sec |
|
|
170
|
+
| 10 | 5 | 1.8× | 54 tokens/sec |
|
|
171
|
+
| 15 | 5 | 2.2× | 66 tokens/sec |
|
|
172
|
+
| 20 | 7 | 2.3× | 69 tokens/sec |
|
|
173
|
+
|
|
174
|
+
**Hardware configurations (A100 GPU)**:
|
|
175
|
+
|
|
176
|
+
| Model Size | Recommended W | Recommended N |
|
|
177
|
+
|------------|---------------|---------------|
|
|
178
|
+
| 7B | 15 | 5 |
|
|
179
|
+
| 13B | 10 | 5 |
|
|
180
|
+
| 33B | 7 | 5 |
|
|
181
|
+
| 70B | 5 | 3 |
|
|
182
|
+
|
|
183
|
+
**Rule**: Larger models → smaller W, N (more expensive to verify)
|
|
184
|
+
|
|
185
|
+
### Scaling Law
|
|
186
|
+
|
|
187
|
+
**Key finding from paper**:
|
|
188
|
+
|
|
189
|
+
"When n-gram size is sufficiently large, exponentially increasing future token guesses can linearly reduce decoding steps."
|
|
190
|
+
|
|
191
|
+
```
|
|
192
|
+
Speedup ≈ 1 + (W × acceptance_rate)
|
|
193
|
+
|
|
194
|
+
where acceptance_rate depends on:
|
|
195
|
+
- Model quality (better models = higher acceptance)
|
|
196
|
+
- Task type (code generation > chat)
|
|
197
|
+
- N-gram size (larger N = higher acceptance but more compute)
|
|
198
|
+
```
|
|
199
|
+
|
|
200
|
+
## Hyperparameter Tuning
|
|
201
|
+
|
|
202
|
+
### Window Size (W)
|
|
203
|
+
|
|
204
|
+
```python
|
|
205
|
+
# Trade-off: Larger W = more candidates but more verification cost
|
|
206
|
+
|
|
207
|
+
W = 5 # Conservative (low overhead, moderate speedup)
|
|
208
|
+
W = 10 # Balanced
|
|
209
|
+
W = 15 # Standard (from paper, 7B models)
|
|
210
|
+
W = 20 # Aggressive (diminishing returns)
|
|
211
|
+
|
|
212
|
+
# Rule: W should be ~2-3× average token acceptance length
|
|
213
|
+
```
|
|
214
|
+
|
|
215
|
+
### N-gram Size (N)
|
|
216
|
+
|
|
217
|
+
```python
|
|
218
|
+
# Trade-off: Larger N = better context but slower generation
|
|
219
|
+
|
|
220
|
+
N = 3 # Fast generation, less context
|
|
221
|
+
N = 5 # Standard (from paper)
|
|
222
|
+
N = 7 # Better context, slower
|
|
223
|
+
|
|
224
|
+
# Rule: N should be large enough to capture local patterns
|
|
225
|
+
```
|
|
226
|
+
|
|
227
|
+
### Guess Size (G)
|
|
228
|
+
|
|
229
|
+
```python
|
|
230
|
+
# Number of parallel n-gram candidates per position
|
|
231
|
+
|
|
232
|
+
G = 1 # Deterministic (fastest, lower acceptance)
|
|
233
|
+
G = 5 # Standard (good balance)
|
|
234
|
+
G = 10 # More exploration (higher acceptance, more compute)
|
|
235
|
+
```
|
|
236
|
+
|
|
237
|
+
## Comparison with Other Methods
|
|
238
|
+
|
|
239
|
+
| Method | Speedup | Training | Draft Model | Memory |
|
|
240
|
+
|--------|---------|----------|-------------|---------|
|
|
241
|
+
| **Lookahead** | 1.5-2.3× | None | No | Base only |
|
|
242
|
+
| Draft Speculative | 1.5-2× | None | Yes | Base + draft |
|
|
243
|
+
| Medusa | 2-3.6× | Minimal | No | Base + heads |
|
|
244
|
+
|
|
245
|
+
**Advantages of Lookahead**:
|
|
246
|
+
- Zero training required
|
|
247
|
+
- No draft model needed
|
|
248
|
+
- Works out-of-the-box with any model
|
|
249
|
+
- No model modification
|
|
250
|
+
|
|
251
|
+
**Disadvantages**:
|
|
252
|
+
- Lower speedup than Medusa
|
|
253
|
+
- More complex implementation
|
|
254
|
+
- Verification overhead
|
|
255
|
+
|
|
256
|
+
## Task-Specific Performance
|
|
257
|
+
|
|
258
|
+
**From paper**:
|
|
259
|
+
|
|
260
|
+
| Task | Baseline | Lookahead | Speedup |
|
|
261
|
+
|------|----------|-----------|---------|
|
|
262
|
+
| **HumanEval (code)** | 30 tok/s | 69 tok/s | 2.3× |
|
|
263
|
+
| **MT-Bench (chat)** | 35 tok/s | 56 tok/s | 1.6× |
|
|
264
|
+
| **GSM8K (math)** | 32 tok/s | 54 tok/s | 1.7× |
|
|
265
|
+
|
|
266
|
+
**Why code is faster**: Higher n-gram predictability (syntax, patterns).
|
|
267
|
+
|
|
268
|
+
## Production Deployment
|
|
269
|
+
|
|
270
|
+
### Integration Example
|
|
271
|
+
|
|
272
|
+
```python
|
|
273
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
274
|
+
|
|
275
|
+
# Load model
|
|
276
|
+
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
|
|
277
|
+
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
|
278
|
+
|
|
279
|
+
# Initialize Lookahead
|
|
280
|
+
lookahead = LookaheadDecoding(
|
|
281
|
+
model=model,
|
|
282
|
+
W=15, # Window size
|
|
283
|
+
N=5, # N-gram size
|
|
284
|
+
G=5 # Guess size
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
# Generate
|
|
288
|
+
prompt = "Write a Python function to calculate fibonacci:"
|
|
289
|
+
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
|
290
|
+
|
|
291
|
+
output = lookahead.generate(input_ids, max_new_tokens=256)
|
|
292
|
+
response = tokenizer.decode(output[0], skip_special_tokens=True)
|
|
293
|
+
|
|
294
|
+
print(response)
|
|
295
|
+
```
|
|
296
|
+
|
|
297
|
+
### Optimization Tips
|
|
298
|
+
|
|
299
|
+
1. **Batch processing**: Verify multiple n-grams in single forward pass
|
|
300
|
+
2. **Caching**: Reuse KV cache across verification steps
|
|
301
|
+
3. **Early stopping**: Stop generation when no candidates match
|
|
302
|
+
4. **Adaptive parameters**: Adjust W, N based on acceptance rate
|
|
303
|
+
|
|
304
|
+
## Resources
|
|
305
|
+
|
|
306
|
+
- **Blog Post**: https://lmsys.org/blog/2023-11-21-lookahead-decoding/
|
|
307
|
+
- **GitHub**: https://github.com/hao-ai-lab/LookaheadDecoding
|
|
308
|
+
- **Paper**: ICML 2024 (Break the Sequential Dependency of LLM Inference Using Lookahead Decoding)
|
|
309
|
+
- **NVIDIA Blog**: https://developer.nvidia.com/blog/optimizing-qwen2-5-coder-throughput-with-nvidia-tensorrt-llm-lookahead-decoding/
|
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
# Medusa: Multiple Decoding Heads
|
|
2
|
+
|
|
3
|
+
Based on arXiv 2401.10774 (2024) - MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
**Source**: https://arxiv.org/abs/2401.10774
|
|
8
|
+
**GitHub**: https://github.com/FasterDecoding/Medusa
|
|
9
|
+
|
|
10
|
+
Medusa augments LLM inference by adding extra decoding heads to predict multiple subsequent tokens in parallel, achieving 2.2-3.6× speedup without quality loss.
|
|
11
|
+
|
|
12
|
+
## Architecture
|
|
13
|
+
|
|
14
|
+
### Core Innovation
|
|
15
|
+
|
|
16
|
+
Instead of separate draft model, add multiple prediction heads to existing LLM:
|
|
17
|
+
|
|
18
|
+
```
|
|
19
|
+
Input → Base LLM (frozen or fine-tuned) → Hidden State
|
|
20
|
+
├→ Head 0 (original, predicts t+1)
|
|
21
|
+
├→ Head 1 (predicts t+2)
|
|
22
|
+
├→ Head 2 (predicts t+3)
|
|
23
|
+
└→ Head 3 (predicts t+4)
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
### Tree-Based Attention
|
|
27
|
+
|
|
28
|
+
**Key mechanism**: Construct candidate tree, verify all paths in single forward pass.
|
|
29
|
+
|
|
30
|
+
Example with 2 heads, top-2 candidates per head:
|
|
31
|
+
|
|
32
|
+
```
|
|
33
|
+
            Root (current token)
|
|
34
|
+
           /                \
|
|
35
|
+
   Candidate 1a        Candidate 1b     (Head 1: 2 options)
|
|
36
|
+
    /      \             /      \
|
|
37
|
+
  C2a      C2b         C2c      C2d     (Head 2: 4 total paths)
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
Single forward pass evaluates entire tree (4 candidates) in parallel!
|
|
41
|
+
|
|
42
|
+
## Training Methods
|
|
43
|
+
|
|
44
|
+
### Medusa-1: Frozen Backbone
|
|
45
|
+
|
|
46
|
+
**Approach**: Keep base LLM frozen, train only Medusa heads.
|
|
47
|
+
|
|
48
|
+
**Advantages**:
|
|
49
|
+
- Lossless (base model unchanged)
|
|
50
|
+
- Fast training (~few hours on 8 GPUs)
|
|
51
|
+
- Minimal data needed (~10M tokens)
|
|
52
|
+
|
|
53
|
+
**Performance**: 2.2× speedup
|
|
54
|
+
|
|
55
|
+
```python
|
|
56
|
+
# Training loop for Medusa-1
|
|
57
|
+
for batch in dataloader:
|
|
58
|
+
# Frozen base model
|
|
59
|
+
with torch.no_grad():
|
|
60
|
+
hidden_states = base_model(**batch, output_hidden_states=True).hidden_states[-1]
|
|
61
|
+
|
|
62
|
+
# Train Medusa heads
|
|
63
|
+
loss = 0.0  # accumulate loss across heads
for i, head in enumerate(medusa_heads):
|
|
64
|
+
logits = head(hidden_states)
|
|
65
|
+
# Target: tokens shifted by (i+1) positions
|
|
66
|
+
targets = batch['input_ids'][:, i+1:]
|
|
67
|
+
loss += F.cross_entropy(logits[:, :-i-1], targets)
|
|
68
|
+
|
|
69
|
+
loss.backward()
|
|
70
|
+
optimizer.step()
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
**Training Data**: Any text corpus (Wikipedia, C4, etc.)
|
|
74
|
+
|
|
75
|
+
### Medusa-2: Joint Fine-Tuning
|
|
76
|
+
|
|
77
|
+
**Approach**: Fine-tune base LLM + Medusa heads together.
|
|
78
|
+
|
|
79
|
+
**Advantages**:
|
|
80
|
+
- Better prediction accuracy (heads aligned with base)
|
|
81
|
+
- Higher speedup (2.3-3.6×)
|
|
82
|
+
|
|
83
|
+
**Challenge**: Must preserve base model capabilities
|
|
84
|
+
|
|
85
|
+
**Solution**: Special training recipe:
|
|
86
|
+
1. Start with pre-trained base model
|
|
87
|
+
2. Add Medusa heads
|
|
88
|
+
3. Fine-tune both together with careful LR scheduling
|
|
89
|
+
4. Use high-quality data to avoid degradation
|
|
90
|
+
|
|
91
|
+
```python
|
|
92
|
+
# Medusa-2 training
|
|
93
|
+
# All parameters trainable
|
|
94
|
+
for param in base_model.parameters():
|
|
95
|
+
param.requires_grad = True # Unfreeze base
|
|
96
|
+
|
|
97
|
+
for param in medusa_heads.parameters():
|
|
98
|
+
param.requires_grad = True
|
|
99
|
+
|
|
100
|
+
# Different learning rates
|
|
101
|
+
optimizer = torch.optim.AdamW([
|
|
102
|
+
{'params': base_model.parameters(), 'lr': 1e-5}, # Lower for base
|
|
103
|
+
{'params': medusa_heads.parameters(), 'lr': 1e-3}, # Higher for heads
|
|
104
|
+
])
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
**Performance**: 2.3-3.6× speedup
|
|
108
|
+
|
|
109
|
+
## Inference Algorithm
|
|
110
|
+
|
|
111
|
+
### Candidate Generation
|
|
112
|
+
|
|
113
|
+
```python
|
|
114
|
+
def medusa_generate_candidates(base_logits, medusa_head_logits, top_k=10):
|
|
115
|
+
"""Generate candidate sequences using tree structure."""
|
|
116
|
+
candidates = []
|
|
117
|
+
|
|
118
|
+
# Base token (original LLM output)
|
|
119
|
+
base_token = torch.argmax(base_logits, dim=-1)
|
|
120
|
+
|
|
121
|
+
# For each Medusa head, get top-k predictions
|
|
122
|
+
medusa_candidates = []
|
|
123
|
+
for head_logits in medusa_head_logits:
|
|
124
|
+
top_k_tokens = torch.topk(head_logits, k=top_k, dim=-1).indices
|
|
125
|
+
medusa_candidates.append(top_k_tokens)
|
|
126
|
+
|
|
127
|
+
# Build candidate tree (all combinations)
|
|
128
|
+
# With 4 heads, top-2 each: 2^4 = 16 candidates
|
|
129
|
+
for combo in itertools.product(*medusa_candidates):
|
|
130
|
+
candidate = [base_token] + list(combo)
|
|
131
|
+
candidates.append(candidate)
|
|
132
|
+
|
|
133
|
+
return candidates # Shape: (num_candidates, seq_len)
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
### Tree Verification
|
|
137
|
+
|
|
138
|
+
```python
|
|
139
|
+
def medusa_verify_candidates(model, candidates, past_key_values):
|
|
140
|
+
"""Verify all candidates in single forward pass using tree attention."""
|
|
141
|
+
# Construct tree attention mask
|
|
142
|
+
# All candidates share prefix, diverge at different points
|
|
143
|
+
attention_mask = build_tree_attention_mask(candidates)
|
|
144
|
+
|
|
145
|
+
# Single forward pass for all candidates
|
|
146
|
+
outputs = model(
|
|
147
|
+
input_ids=candidates,
|
|
148
|
+
attention_mask=attention_mask,
|
|
149
|
+
past_key_values=past_key_values,
|
|
150
|
+
use_cache=True
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
# Score each candidate
|
|
154
|
+
scores = compute_acceptance_scores(outputs.logits, candidates)
|
|
155
|
+
|
|
156
|
+
# Accept longest valid candidate
|
|
157
|
+
best_candidate = select_best(candidates, scores)
|
|
158
|
+
|
|
159
|
+
return best_candidate
|
|
160
|
+
```
|
|
161
|
+
|
|
162
|
+
### Acceptance Criterion
|
|
163
|
+
|
|
164
|
+
**Posterior threshold**: Accept token if probability exceeds threshold.
|
|
165
|
+
|
|
166
|
+
```python
|
|
167
|
+
def should_accept(token, token_prob, threshold=0.09):
|
|
168
|
+
"""Medusa acceptance criterion."""
|
|
169
|
+
return token_prob >= threshold
|
|
170
|
+
|
|
171
|
+
# Typical thresholds:
|
|
172
|
+
# - 0.09: Standard (from paper)
|
|
173
|
+
# - 0.05: Lenient (lower bar, fewer rejections, faster but may reduce quality)
|
|
174
|
+
# - 0.15: Strict (higher bar, more rejections, slower but preserves quality)
|
|
175
|
+
```
|
|
176
|
+
|
|
177
|
+
## Performance Results
|
|
178
|
+
|
|
179
|
+
**From paper (Vicuna-7B, MT-Bench):**
|
|
180
|
+
|
|
181
|
+
| Configuration | Speedup | Quality (MT-Bench score) |
|
|
182
|
+
|---------------|---------|--------------------------|
|
|
183
|
+
| Baseline | 1.0× | 6.57 |
|
|
184
|
+
| Medusa-1 (frozen) | 2.2× | 6.57 (lossless) |
|
|
185
|
+
| Medusa-2 (joint) | 2.3× | 6.60 (+0.03) |
|
|
186
|
+
| Medusa-2 (optimized) | 3.6× | 6.55 (-0.02) |
|
|
187
|
+
|
|
188
|
+
**Key findings**:
|
|
189
|
+
- Medusa-1: No quality degradation (frozen base)
|
|
190
|
+
- Medusa-2: Slight quality improvement possible
|
|
191
|
+
- Trade-off: More aggressive = faster but may reduce quality
|
|
192
|
+
|
|
193
|
+
## Hyperparameter Tuning
|
|
194
|
+
|
|
195
|
+
### Number of Heads
|
|
196
|
+
|
|
197
|
+
```python
|
|
198
|
+
# Typical configurations:
|
|
199
|
+
num_heads = 2 # Conservative (2× speedup)
|
|
200
|
+
num_heads = 3 # Balanced (2.5× speedup)
|
|
201
|
+
num_heads = 4 # Standard (3× speedup, from paper)
|
|
202
|
+
num_heads = 5 # Aggressive (3.5×+ speedup)
|
|
203
|
+
|
|
204
|
+
# Rule: More heads = more candidates but also more computation
|
|
205
|
+
# Optimal: 3-4 heads for most models
|
|
206
|
+
```
|
|
207
|
+
|
|
208
|
+
### Top-K per Head
|
|
209
|
+
|
|
210
|
+
```python
|
|
211
|
+
# Candidates per head
|
|
212
|
+
top_k = 2 # Standard (2^num_heads total candidates)
|
|
213
|
+
top_k = 3 # More candidates (3^num_heads)
|
|
214
|
+
top_k = 5 # Many candidates (5^num_heads)
|
|
215
|
+
|
|
216
|
+
# Example with 4 heads:
|
|
217
|
+
# top_k=2: 16 candidates (fast)
|
|
218
|
+
# top_k=3: 81 candidates (slower verification)
|
|
219
|
+
```
|
|
220
|
+
|
|
221
|
+
### Tree Construction
|
|
222
|
+
|
|
223
|
+
**Medusa Choices** (which candidate paths to explore):
|
|
224
|
+
|
|
225
|
+
```python
|
|
226
|
+
# Standard configuration (from paper)
|
|
227
|
+
medusa_choices = [
|
|
228
|
+
[0], # Only head 0
|
|
229
|
+
[0, 0], # Head 0, then head 1 (first candidate)
|
|
230
|
+
[0, 1], # Head 0, then head 1 (second candidate)
|
|
231
|
+
[0, 0, 0], # All heads (first path)
|
|
232
|
+
]
|
|
233
|
+
|
|
234
|
+
# Aggressive configuration (more paths)
|
|
235
|
+
medusa_choices = [
|
|
236
|
+
[0],
|
|
237
|
+
[0, 0], [0, 1],
|
|
238
|
+
[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
|
|
239
|
+
]
|
|
240
|
+
```
|
|
241
|
+
|
|
242
|
+
## Training Recipe
|
|
243
|
+
|
|
244
|
+
### Data Requirements
|
|
245
|
+
|
|
246
|
+
**Medusa-1**:
|
|
247
|
+
- Amount: 10M-100M tokens
|
|
248
|
+
- Quality: Any text corpus works
|
|
249
|
+
- Time: 2-8 hours on 8× A100
|
|
250
|
+
|
|
251
|
+
**Medusa-2**:
|
|
252
|
+
- Amount: 100M-1B tokens
|
|
253
|
+
- Quality: High-quality (same domain as target use case)
|
|
254
|
+
- Time: 1-3 days on 8× A100
|
|
255
|
+
|
|
256
|
+
### Training Script
|
|
257
|
+
|
|
258
|
+
```bash
|
|
259
|
+
# Clone Medusa repo
|
|
260
|
+
git clone https://github.com/FasterDecoding/Medusa
|
|
261
|
+
cd Medusa
|
|
262
|
+
|
|
263
|
+
# Train Medusa-1 (frozen base)
|
|
264
|
+
python medusa/train/train.py \
|
|
265
|
+
--model_name_or_path lmsys/vicuna-7b-v1.3 \
|
|
266
|
+
--data_path ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
|
|
267
|
+
--bf16 True \
|
|
268
|
+
--output_dir medusa-vicuna-7b-v1.3 \
|
|
269
|
+
--num_train_epochs 3 \
|
|
270
|
+
--per_device_train_batch_size 4 \
|
|
271
|
+
--gradient_accumulation_steps 8 \
|
|
272
|
+
--learning_rate 1e-3 \
|
|
273
|
+
--medusa_num_heads 4 \
|
|
274
|
+
--medusa_num_layers 1 \
|
|
275
|
+
--freeze_base_model True # Medusa-1
|
|
276
|
+
|
|
277
|
+
# Train Medusa-2 (joint fine-tuning)
|
|
278
|
+
python medusa/train/train.py \
|
|
279
|
+
--model_name_or_path lmsys/vicuna-7b-v1.3 \
|
|
280
|
+
--data_path high_quality_data.json \
|
|
281
|
+
--bf16 True \
|
|
282
|
+
--output_dir medusa-vicuna-7b-v1.3-joint \
|
|
283
|
+
--num_train_epochs 1 \
|
|
284
|
+
--per_device_train_batch_size 4 \
|
|
285
|
+
--gradient_accumulation_steps 8 \
|
|
286
|
+
--learning_rate 1e-5 \
|
|
287
|
+
--medusa_num_heads 4 \
|
|
288
|
+
--freeze_base_model False # Medusa-2 (joint)
|
|
289
|
+
```
|
|
290
|
+
|
|
291
|
+
## Deployment
|
|
292
|
+
|
|
293
|
+
### Loading Medusa Model
|
|
294
|
+
|
|
295
|
+
```python
|
|
296
|
+
from medusa.model.medusa_model import MedusaModel
|
|
297
|
+
|
|
298
|
+
# Load pre-trained Medusa model
|
|
299
|
+
model = MedusaModel.from_pretrained(
|
|
300
|
+
"FasterDecoding/medusa-vicuna-7b-v1.3",
|
|
301
|
+
torch_dtype=torch.float16,
|
|
302
|
+
device_map="auto"
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
# Or load base + Medusa heads separately
|
|
306
|
+
base_model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.3")
|
|
307
|
+
medusa_heads = torch.load("medusa_heads.pt")
|
|
308
|
+
model = MedusaModel(base_model, medusa_heads)
|
|
309
|
+
```
|
|
310
|
+
|
|
311
|
+
### Generation
|
|
312
|
+
|
|
313
|
+
```python
|
|
314
|
+
# Generate with Medusa
|
|
315
|
+
outputs = model.medusa_generate(
|
|
316
|
+
input_ids,
|
|
317
|
+
max_new_tokens=256,
|
|
318
|
+
temperature=0.7,
|
|
319
|
+
posterior_threshold=0.09, # Acceptance threshold
|
|
320
|
+
posterior_alpha=0.3, # Tree construction parameter
|
|
321
|
+
medusa_choices=medusa_choices, # Candidate paths
|
|
322
|
+
)
|
|
323
|
+
```
|
|
324
|
+
|
|
325
|
+
## Comparison with Speculative Decoding
|
|
326
|
+
|
|
327
|
+
| Aspect | Medusa | Speculative Decoding |
|
|
328
|
+
|--------|--------|----------------------|
|
|
329
|
+
| **Draft Model** | Built-in (heads) | External (separate model) |
|
|
330
|
+
| **Training** | Minimal (heads only) | None (use existing small model) |
|
|
331
|
+
| **Memory** | Base + heads (~1-2% overhead) | Base + draft (can be large) |
|
|
332
|
+
| **Speedup** | 2-3.6× | 1.5-2× |
|
|
333
|
+
| **Deployment** | Single model | Two models |
|
|
334
|
+
|
|
335
|
+
**When to use Medusa**:
|
|
336
|
+
- Want single model deployment
|
|
337
|
+
- Can afford minimal training
|
|
338
|
+
- Need best speedup (3×+)
|
|
339
|
+
|
|
340
|
+
**When to use Speculative**:
|
|
341
|
+
- Have existing small model
|
|
342
|
+
- Zero training budget
|
|
343
|
+
- Simpler setup
|
|
344
|
+
|
|
345
|
+
## Resources
|
|
346
|
+
|
|
347
|
+
- **Paper**: https://arxiv.org/abs/2401.10774
|
|
348
|
+
- **GitHub**: https://github.com/FasterDecoding/Medusa
|
|
349
|
+
- **Blog**: https://www.together.ai/blog/medusa
|
|
350
|
+
- **Demo**: https://sites.google.com/view/medusa-llm
|