@synsci/cli-darwin-x64 1.1.49
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/skills/accelerate/SKILL.md +332 -0
- package/bin/skills/accelerate/references/custom-plugins.md +453 -0
- package/bin/skills/accelerate/references/megatron-integration.md +489 -0
- package/bin/skills/accelerate/references/performance.md +525 -0
- package/bin/skills/audiocraft/SKILL.md +564 -0
- package/bin/skills/audiocraft/references/advanced-usage.md +666 -0
- package/bin/skills/audiocraft/references/troubleshooting.md +504 -0
- package/bin/skills/autogpt/SKILL.md +403 -0
- package/bin/skills/autogpt/references/advanced-usage.md +535 -0
- package/bin/skills/autogpt/references/troubleshooting.md +420 -0
- package/bin/skills/awq/SKILL.md +310 -0
- package/bin/skills/awq/references/advanced-usage.md +324 -0
- package/bin/skills/awq/references/troubleshooting.md +344 -0
- package/bin/skills/axolotl/SKILL.md +158 -0
- package/bin/skills/axolotl/references/api.md +5548 -0
- package/bin/skills/axolotl/references/dataset-formats.md +1029 -0
- package/bin/skills/axolotl/references/index.md +15 -0
- package/bin/skills/axolotl/references/other.md +3563 -0
- package/bin/skills/bigcode-evaluation-harness/SKILL.md +405 -0
- package/bin/skills/bigcode-evaluation-harness/references/benchmarks.md +393 -0
- package/bin/skills/bigcode-evaluation-harness/references/custom-tasks.md +424 -0
- package/bin/skills/bigcode-evaluation-harness/references/issues.md +394 -0
- package/bin/skills/bitsandbytes/SKILL.md +411 -0
- package/bin/skills/bitsandbytes/references/memory-optimization.md +521 -0
- package/bin/skills/bitsandbytes/references/qlora-training.md +521 -0
- package/bin/skills/bitsandbytes/references/quantization-formats.md +447 -0
- package/bin/skills/blip-2/SKILL.md +564 -0
- package/bin/skills/blip-2/references/advanced-usage.md +680 -0
- package/bin/skills/blip-2/references/troubleshooting.md +526 -0
- package/bin/skills/chroma/SKILL.md +406 -0
- package/bin/skills/chroma/references/integration.md +38 -0
- package/bin/skills/clip/SKILL.md +253 -0
- package/bin/skills/clip/references/applications.md +207 -0
- package/bin/skills/constitutional-ai/SKILL.md +290 -0
- package/bin/skills/crewai/SKILL.md +498 -0
- package/bin/skills/crewai/references/flows.md +438 -0
- package/bin/skills/crewai/references/tools.md +429 -0
- package/bin/skills/crewai/references/troubleshooting.md +480 -0
- package/bin/skills/deepspeed/SKILL.md +141 -0
- package/bin/skills/deepspeed/references/08.md +17 -0
- package/bin/skills/deepspeed/references/09.md +173 -0
- package/bin/skills/deepspeed/references/2020.md +378 -0
- package/bin/skills/deepspeed/references/2023.md +279 -0
- package/bin/skills/deepspeed/references/assets.md +179 -0
- package/bin/skills/deepspeed/references/index.md +35 -0
- package/bin/skills/deepspeed/references/mii.md +118 -0
- package/bin/skills/deepspeed/references/other.md +1191 -0
- package/bin/skills/deepspeed/references/tutorials.md +6554 -0
- package/bin/skills/dspy/SKILL.md +590 -0
- package/bin/skills/dspy/references/examples.md +663 -0
- package/bin/skills/dspy/references/modules.md +475 -0
- package/bin/skills/dspy/references/optimizers.md +566 -0
- package/bin/skills/faiss/SKILL.md +221 -0
- package/bin/skills/faiss/references/index_types.md +280 -0
- package/bin/skills/flash-attention/SKILL.md +367 -0
- package/bin/skills/flash-attention/references/benchmarks.md +215 -0
- package/bin/skills/flash-attention/references/transformers-integration.md +293 -0
- package/bin/skills/gguf/SKILL.md +427 -0
- package/bin/skills/gguf/references/advanced-usage.md +504 -0
- package/bin/skills/gguf/references/troubleshooting.md +442 -0
- package/bin/skills/gptq/SKILL.md +450 -0
- package/bin/skills/gptq/references/calibration.md +337 -0
- package/bin/skills/gptq/references/integration.md +129 -0
- package/bin/skills/gptq/references/troubleshooting.md +95 -0
- package/bin/skills/grpo-rl-training/README.md +97 -0
- package/bin/skills/grpo-rl-training/SKILL.md +572 -0
- package/bin/skills/grpo-rl-training/examples/reward_functions_library.py +393 -0
- package/bin/skills/grpo-rl-training/templates/basic_grpo_training.py +228 -0
- package/bin/skills/guidance/SKILL.md +572 -0
- package/bin/skills/guidance/references/backends.md +554 -0
- package/bin/skills/guidance/references/constraints.md +674 -0
- package/bin/skills/guidance/references/examples.md +767 -0
- package/bin/skills/hqq/SKILL.md +445 -0
- package/bin/skills/hqq/references/advanced-usage.md +528 -0
- package/bin/skills/hqq/references/troubleshooting.md +503 -0
- package/bin/skills/hugging-face-cli/SKILL.md +191 -0
- package/bin/skills/hugging-face-cli/references/commands.md +954 -0
- package/bin/skills/hugging-face-cli/references/examples.md +374 -0
- package/bin/skills/hugging-face-datasets/SKILL.md +547 -0
- package/bin/skills/hugging-face-datasets/examples/diverse_training_examples.json +239 -0
- package/bin/skills/hugging-face-datasets/examples/system_prompt_template.txt +196 -0
- package/bin/skills/hugging-face-datasets/examples/training_examples.json +176 -0
- package/bin/skills/hugging-face-datasets/scripts/dataset_manager.py +522 -0
- package/bin/skills/hugging-face-datasets/scripts/sql_manager.py +844 -0
- package/bin/skills/hugging-face-datasets/templates/chat.json +55 -0
- package/bin/skills/hugging-face-datasets/templates/classification.json +62 -0
- package/bin/skills/hugging-face-datasets/templates/completion.json +51 -0
- package/bin/skills/hugging-face-datasets/templates/custom.json +75 -0
- package/bin/skills/hugging-face-datasets/templates/qa.json +54 -0
- package/bin/skills/hugging-face-datasets/templates/tabular.json +81 -0
- package/bin/skills/hugging-face-evaluation/SKILL.md +656 -0
- package/bin/skills/hugging-face-evaluation/examples/USAGE_EXAMPLES.md +382 -0
- package/bin/skills/hugging-face-evaluation/examples/artificial_analysis_to_hub.py +141 -0
- package/bin/skills/hugging-face-evaluation/examples/example_readme_tables.md +135 -0
- package/bin/skills/hugging-face-evaluation/examples/metric_mapping.json +50 -0
- package/bin/skills/hugging-face-evaluation/requirements.txt +20 -0
- package/bin/skills/hugging-face-evaluation/scripts/evaluation_manager.py +1374 -0
- package/bin/skills/hugging-face-evaluation/scripts/inspect_eval_uv.py +104 -0
- package/bin/skills/hugging-face-evaluation/scripts/inspect_vllm_uv.py +317 -0
- package/bin/skills/hugging-face-evaluation/scripts/lighteval_vllm_uv.py +303 -0
- package/bin/skills/hugging-face-evaluation/scripts/run_eval_job.py +98 -0
- package/bin/skills/hugging-face-evaluation/scripts/run_vllm_eval_job.py +331 -0
- package/bin/skills/hugging-face-evaluation/scripts/test_extraction.py +206 -0
- package/bin/skills/hugging-face-jobs/SKILL.md +1041 -0
- package/bin/skills/hugging-face-jobs/index.html +216 -0
- package/bin/skills/hugging-face-jobs/references/hardware_guide.md +336 -0
- package/bin/skills/hugging-face-jobs/references/hub_saving.md +352 -0
- package/bin/skills/hugging-face-jobs/references/token_usage.md +546 -0
- package/bin/skills/hugging-face-jobs/references/troubleshooting.md +475 -0
- package/bin/skills/hugging-face-jobs/scripts/cot-self-instruct.py +718 -0
- package/bin/skills/hugging-face-jobs/scripts/finepdfs-stats.py +546 -0
- package/bin/skills/hugging-face-jobs/scripts/generate-responses.py +587 -0
- package/bin/skills/hugging-face-model-trainer/SKILL.md +711 -0
- package/bin/skills/hugging-face-model-trainer/references/gguf_conversion.md +296 -0
- package/bin/skills/hugging-face-model-trainer/references/hardware_guide.md +283 -0
- package/bin/skills/hugging-face-model-trainer/references/hub_saving.md +364 -0
- package/bin/skills/hugging-face-model-trainer/references/reliability_principles.md +371 -0
- package/bin/skills/hugging-face-model-trainer/references/trackio_guide.md +189 -0
- package/bin/skills/hugging-face-model-trainer/references/training_methods.md +150 -0
- package/bin/skills/hugging-face-model-trainer/references/training_patterns.md +203 -0
- package/bin/skills/hugging-face-model-trainer/references/troubleshooting.md +282 -0
- package/bin/skills/hugging-face-model-trainer/scripts/convert_to_gguf.py +424 -0
- package/bin/skills/hugging-face-model-trainer/scripts/dataset_inspector.py +417 -0
- package/bin/skills/hugging-face-model-trainer/scripts/estimate_cost.py +150 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_dpo_example.py +106 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_grpo_example.py +89 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_sft_example.py +122 -0
- package/bin/skills/hugging-face-paper-publisher/SKILL.md +627 -0
- package/bin/skills/hugging-face-paper-publisher/examples/example_usage.md +327 -0
- package/bin/skills/hugging-face-paper-publisher/references/quick_reference.md +216 -0
- package/bin/skills/hugging-face-paper-publisher/scripts/paper_manager.py +508 -0
- package/bin/skills/hugging-face-paper-publisher/templates/arxiv.md +299 -0
- package/bin/skills/hugging-face-paper-publisher/templates/ml-report.md +358 -0
- package/bin/skills/hugging-face-paper-publisher/templates/modern.md +319 -0
- package/bin/skills/hugging-face-paper-publisher/templates/standard.md +201 -0
- package/bin/skills/hugging-face-tool-builder/SKILL.md +115 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.py +57 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.sh +40 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.tsx +57 -0
- package/bin/skills/hugging-face-tool-builder/references/find_models_by_paper.sh +230 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_enrich_models.sh +96 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_model_card_frontmatter.sh +188 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_model_papers_auth.sh +171 -0
- package/bin/skills/hugging-face-trackio/SKILL.md +65 -0
- package/bin/skills/hugging-face-trackio/references/logging_metrics.md +206 -0
- package/bin/skills/hugging-face-trackio/references/retrieving_metrics.md +223 -0
- package/bin/skills/huggingface-tokenizers/SKILL.md +516 -0
- package/bin/skills/huggingface-tokenizers/references/algorithms.md +653 -0
- package/bin/skills/huggingface-tokenizers/references/integration.md +637 -0
- package/bin/skills/huggingface-tokenizers/references/pipeline.md +723 -0
- package/bin/skills/huggingface-tokenizers/references/training.md +565 -0
- package/bin/skills/instructor/SKILL.md +740 -0
- package/bin/skills/instructor/references/examples.md +107 -0
- package/bin/skills/instructor/references/providers.md +70 -0
- package/bin/skills/instructor/references/validation.md +606 -0
- package/bin/skills/knowledge-distillation/SKILL.md +458 -0
- package/bin/skills/knowledge-distillation/references/minillm.md +334 -0
- package/bin/skills/lambda-labs/SKILL.md +545 -0
- package/bin/skills/lambda-labs/references/advanced-usage.md +611 -0
- package/bin/skills/lambda-labs/references/troubleshooting.md +530 -0
- package/bin/skills/langchain/SKILL.md +480 -0
- package/bin/skills/langchain/references/agents.md +499 -0
- package/bin/skills/langchain/references/integration.md +562 -0
- package/bin/skills/langchain/references/rag.md +600 -0
- package/bin/skills/langsmith/SKILL.md +422 -0
- package/bin/skills/langsmith/references/advanced-usage.md +548 -0
- package/bin/skills/langsmith/references/troubleshooting.md +537 -0
- package/bin/skills/litgpt/SKILL.md +469 -0
- package/bin/skills/litgpt/references/custom-models.md +568 -0
- package/bin/skills/litgpt/references/distributed-training.md +451 -0
- package/bin/skills/litgpt/references/supported-models.md +336 -0
- package/bin/skills/litgpt/references/training-recipes.md +619 -0
- package/bin/skills/llama-cpp/SKILL.md +258 -0
- package/bin/skills/llama-cpp/references/optimization.md +89 -0
- package/bin/skills/llama-cpp/references/quantization.md +213 -0
- package/bin/skills/llama-cpp/references/server.md +125 -0
- package/bin/skills/llama-factory/SKILL.md +80 -0
- package/bin/skills/llama-factory/references/_images.md +23 -0
- package/bin/skills/llama-factory/references/advanced.md +1055 -0
- package/bin/skills/llama-factory/references/getting_started.md +349 -0
- package/bin/skills/llama-factory/references/index.md +19 -0
- package/bin/skills/llama-factory/references/other.md +31 -0
- package/bin/skills/llamaguard/SKILL.md +337 -0
- package/bin/skills/llamaindex/SKILL.md +569 -0
- package/bin/skills/llamaindex/references/agents.md +83 -0
- package/bin/skills/llamaindex/references/data_connectors.md +108 -0
- package/bin/skills/llamaindex/references/query_engines.md +406 -0
- package/bin/skills/llava/SKILL.md +304 -0
- package/bin/skills/llava/references/training.md +197 -0
- package/bin/skills/lm-evaluation-harness/SKILL.md +490 -0
- package/bin/skills/lm-evaluation-harness/references/api-evaluation.md +490 -0
- package/bin/skills/lm-evaluation-harness/references/benchmark-guide.md +488 -0
- package/bin/skills/lm-evaluation-harness/references/custom-tasks.md +602 -0
- package/bin/skills/lm-evaluation-harness/references/distributed-eval.md +519 -0
- package/bin/skills/long-context/SKILL.md +536 -0
- package/bin/skills/long-context/references/extension_methods.md +468 -0
- package/bin/skills/long-context/references/fine_tuning.md +611 -0
- package/bin/skills/long-context/references/rope.md +402 -0
- package/bin/skills/mamba/SKILL.md +260 -0
- package/bin/skills/mamba/references/architecture-details.md +206 -0
- package/bin/skills/mamba/references/benchmarks.md +255 -0
- package/bin/skills/mamba/references/training-guide.md +388 -0
- package/bin/skills/megatron-core/SKILL.md +366 -0
- package/bin/skills/megatron-core/references/benchmarks.md +249 -0
- package/bin/skills/megatron-core/references/parallelism-guide.md +404 -0
- package/bin/skills/megatron-core/references/production-examples.md +473 -0
- package/bin/skills/megatron-core/references/training-recipes.md +547 -0
- package/bin/skills/miles/SKILL.md +315 -0
- package/bin/skills/miles/references/api-reference.md +141 -0
- package/bin/skills/miles/references/troubleshooting.md +352 -0
- package/bin/skills/mlflow/SKILL.md +704 -0
- package/bin/skills/mlflow/references/deployment.md +744 -0
- package/bin/skills/mlflow/references/model-registry.md +770 -0
- package/bin/skills/mlflow/references/tracking.md +680 -0
- package/bin/skills/modal/SKILL.md +341 -0
- package/bin/skills/modal/references/advanced-usage.md +503 -0
- package/bin/skills/modal/references/troubleshooting.md +494 -0
- package/bin/skills/model-merging/SKILL.md +539 -0
- package/bin/skills/model-merging/references/evaluation.md +462 -0
- package/bin/skills/model-merging/references/examples.md +428 -0
- package/bin/skills/model-merging/references/methods.md +352 -0
- package/bin/skills/model-pruning/SKILL.md +495 -0
- package/bin/skills/model-pruning/references/wanda.md +347 -0
- package/bin/skills/moe-training/SKILL.md +526 -0
- package/bin/skills/moe-training/references/architectures.md +432 -0
- package/bin/skills/moe-training/references/inference.md +348 -0
- package/bin/skills/moe-training/references/training.md +425 -0
- package/bin/skills/nanogpt/SKILL.md +290 -0
- package/bin/skills/nanogpt/references/architecture.md +382 -0
- package/bin/skills/nanogpt/references/data.md +476 -0
- package/bin/skills/nanogpt/references/training.md +564 -0
- package/bin/skills/nemo-curator/SKILL.md +383 -0
- package/bin/skills/nemo-curator/references/deduplication.md +87 -0
- package/bin/skills/nemo-curator/references/filtering.md +102 -0
- package/bin/skills/nemo-evaluator/SKILL.md +494 -0
- package/bin/skills/nemo-evaluator/references/adapter-system.md +340 -0
- package/bin/skills/nemo-evaluator/references/configuration.md +447 -0
- package/bin/skills/nemo-evaluator/references/custom-benchmarks.md +315 -0
- package/bin/skills/nemo-evaluator/references/execution-backends.md +361 -0
- package/bin/skills/nemo-guardrails/SKILL.md +297 -0
- package/bin/skills/nnsight/SKILL.md +436 -0
- package/bin/skills/nnsight/references/README.md +78 -0
- package/bin/skills/nnsight/references/api.md +344 -0
- package/bin/skills/nnsight/references/tutorials.md +300 -0
- package/bin/skills/openrlhf/SKILL.md +249 -0
- package/bin/skills/openrlhf/references/algorithm-comparison.md +404 -0
- package/bin/skills/openrlhf/references/custom-rewards.md +530 -0
- package/bin/skills/openrlhf/references/hybrid-engine.md +287 -0
- package/bin/skills/openrlhf/references/multi-node-training.md +454 -0
- package/bin/skills/outlines/SKILL.md +652 -0
- package/bin/skills/outlines/references/backends.md +615 -0
- package/bin/skills/outlines/references/examples.md +773 -0
- package/bin/skills/outlines/references/json_generation.md +652 -0
- package/bin/skills/peft/SKILL.md +431 -0
- package/bin/skills/peft/references/advanced-usage.md +514 -0
- package/bin/skills/peft/references/troubleshooting.md +480 -0
- package/bin/skills/phoenix/SKILL.md +475 -0
- package/bin/skills/phoenix/references/advanced-usage.md +619 -0
- package/bin/skills/phoenix/references/troubleshooting.md +538 -0
- package/bin/skills/pinecone/SKILL.md +358 -0
- package/bin/skills/pinecone/references/deployment.md +181 -0
- package/bin/skills/pytorch-fsdp/SKILL.md +126 -0
- package/bin/skills/pytorch-fsdp/references/index.md +7 -0
- package/bin/skills/pytorch-fsdp/references/other.md +4249 -0
- package/bin/skills/pytorch-lightning/SKILL.md +346 -0
- package/bin/skills/pytorch-lightning/references/callbacks.md +436 -0
- package/bin/skills/pytorch-lightning/references/distributed.md +490 -0
- package/bin/skills/pytorch-lightning/references/hyperparameter-tuning.md +556 -0
- package/bin/skills/pyvene/SKILL.md +473 -0
- package/bin/skills/pyvene/references/README.md +73 -0
- package/bin/skills/pyvene/references/api.md +383 -0
- package/bin/skills/pyvene/references/tutorials.md +376 -0
- package/bin/skills/qdrant/SKILL.md +493 -0
- package/bin/skills/qdrant/references/advanced-usage.md +648 -0
- package/bin/skills/qdrant/references/troubleshooting.md +631 -0
- package/bin/skills/ray-data/SKILL.md +326 -0
- package/bin/skills/ray-data/references/integration.md +82 -0
- package/bin/skills/ray-data/references/transformations.md +83 -0
- package/bin/skills/ray-train/SKILL.md +406 -0
- package/bin/skills/ray-train/references/multi-node.md +628 -0
- package/bin/skills/rwkv/SKILL.md +260 -0
- package/bin/skills/rwkv/references/architecture-details.md +344 -0
- package/bin/skills/rwkv/references/rwkv7.md +386 -0
- package/bin/skills/rwkv/references/state-management.md +369 -0
- package/bin/skills/saelens/SKILL.md +386 -0
- package/bin/skills/saelens/references/README.md +70 -0
- package/bin/skills/saelens/references/api.md +333 -0
- package/bin/skills/saelens/references/tutorials.md +318 -0
- package/bin/skills/segment-anything/SKILL.md +500 -0
- package/bin/skills/segment-anything/references/advanced-usage.md +589 -0
- package/bin/skills/segment-anything/references/troubleshooting.md +484 -0
- package/bin/skills/sentence-transformers/SKILL.md +255 -0
- package/bin/skills/sentence-transformers/references/models.md +123 -0
- package/bin/skills/sentencepiece/SKILL.md +235 -0
- package/bin/skills/sentencepiece/references/algorithms.md +200 -0
- package/bin/skills/sentencepiece/references/training.md +304 -0
- package/bin/skills/sglang/SKILL.md +442 -0
- package/bin/skills/sglang/references/deployment.md +490 -0
- package/bin/skills/sglang/references/radix-attention.md +413 -0
- package/bin/skills/sglang/references/structured-generation.md +541 -0
- package/bin/skills/simpo/SKILL.md +219 -0
- package/bin/skills/simpo/references/datasets.md +478 -0
- package/bin/skills/simpo/references/hyperparameters.md +452 -0
- package/bin/skills/simpo/references/loss-functions.md +350 -0
- package/bin/skills/skypilot/SKILL.md +509 -0
- package/bin/skills/skypilot/references/advanced-usage.md +491 -0
- package/bin/skills/skypilot/references/troubleshooting.md +570 -0
- package/bin/skills/slime/SKILL.md +464 -0
- package/bin/skills/slime/references/api-reference.md +392 -0
- package/bin/skills/slime/references/troubleshooting.md +386 -0
- package/bin/skills/speculative-decoding/SKILL.md +467 -0
- package/bin/skills/speculative-decoding/references/lookahead.md +309 -0
- package/bin/skills/speculative-decoding/references/medusa.md +350 -0
- package/bin/skills/stable-diffusion/SKILL.md +519 -0
- package/bin/skills/stable-diffusion/references/advanced-usage.md +716 -0
- package/bin/skills/stable-diffusion/references/troubleshooting.md +555 -0
- package/bin/skills/tensorboard/SKILL.md +629 -0
- package/bin/skills/tensorboard/references/integrations.md +638 -0
- package/bin/skills/tensorboard/references/profiling.md +545 -0
- package/bin/skills/tensorboard/references/visualization.md +620 -0
- package/bin/skills/tensorrt-llm/SKILL.md +187 -0
- package/bin/skills/tensorrt-llm/references/multi-gpu.md +298 -0
- package/bin/skills/tensorrt-llm/references/optimization.md +242 -0
- package/bin/skills/tensorrt-llm/references/serving.md +470 -0
- package/bin/skills/tinker/SKILL.md +362 -0
- package/bin/skills/tinker/references/api-reference.md +168 -0
- package/bin/skills/tinker/references/getting-started.md +157 -0
- package/bin/skills/tinker/references/loss-functions.md +163 -0
- package/bin/skills/tinker/references/models-and-lora.md +139 -0
- package/bin/skills/tinker/references/recipes.md +280 -0
- package/bin/skills/tinker/references/reinforcement-learning.md +212 -0
- package/bin/skills/tinker/references/rendering.md +243 -0
- package/bin/skills/tinker/references/supervised-learning.md +232 -0
- package/bin/skills/tinker-training-cost/SKILL.md +187 -0
- package/bin/skills/tinker-training-cost/scripts/calculate_cost.py +123 -0
- package/bin/skills/torchforge/SKILL.md +433 -0
- package/bin/skills/torchforge/references/api-reference.md +327 -0
- package/bin/skills/torchforge/references/troubleshooting.md +409 -0
- package/bin/skills/torchtitan/SKILL.md +358 -0
- package/bin/skills/torchtitan/references/checkpoint.md +181 -0
- package/bin/skills/torchtitan/references/custom-models.md +258 -0
- package/bin/skills/torchtitan/references/float8.md +133 -0
- package/bin/skills/torchtitan/references/fsdp.md +126 -0
- package/bin/skills/transformer-lens/SKILL.md +346 -0
- package/bin/skills/transformer-lens/references/README.md +54 -0
- package/bin/skills/transformer-lens/references/api.md +362 -0
- package/bin/skills/transformer-lens/references/tutorials.md +339 -0
- package/bin/skills/trl-fine-tuning/SKILL.md +455 -0
- package/bin/skills/trl-fine-tuning/references/dpo-variants.md +227 -0
- package/bin/skills/trl-fine-tuning/references/online-rl.md +82 -0
- package/bin/skills/trl-fine-tuning/references/reward-modeling.md +122 -0
- package/bin/skills/trl-fine-tuning/references/sft-training.md +168 -0
- package/bin/skills/unsloth/SKILL.md +80 -0
- package/bin/skills/unsloth/references/index.md +7 -0
- package/bin/skills/unsloth/references/llms-full.md +16799 -0
- package/bin/skills/unsloth/references/llms-txt.md +12044 -0
- package/bin/skills/unsloth/references/llms.md +82 -0
- package/bin/skills/verl/SKILL.md +391 -0
- package/bin/skills/verl/references/api-reference.md +301 -0
- package/bin/skills/verl/references/troubleshooting.md +391 -0
- package/bin/skills/vllm/SKILL.md +364 -0
- package/bin/skills/vllm/references/optimization.md +226 -0
- package/bin/skills/vllm/references/quantization.md +284 -0
- package/bin/skills/vllm/references/server-deployment.md +255 -0
- package/bin/skills/vllm/references/troubleshooting.md +447 -0
- package/bin/skills/weights-and-biases/SKILL.md +590 -0
- package/bin/skills/weights-and-biases/references/artifacts.md +584 -0
- package/bin/skills/weights-and-biases/references/integrations.md +700 -0
- package/bin/skills/weights-and-biases/references/sweeps.md +847 -0
- package/bin/skills/whisper/SKILL.md +317 -0
- package/bin/skills/whisper/references/languages.md +189 -0
- package/bin/synsc +0 -0
- package/package.json +10 -0
@@ -0,0 +1,258 @@

# Adding Custom Models to TorchTitan

This guide explains how to add a new model to TorchTitan following the established patterns.

## Directory Structure

```
torchtitan/models/your_model/
├── model/
│   ├── __init__.py
│   ├── args.py                  # Model arguments
│   ├── model.py                 # Model definition
│   └── state_dict_adapter.py    # HF conversion (optional)
├── infra/
│   ├── __init__.py
│   ├── parallelize.py           # TP, FSDP, compile application
│   └── pipeline.py              # PP application (optional)
├── train_configs/
│   ├── debug_model.toml
│   └── your_model_XB.toml
├── __init__.py                  # TrainSpec registration
└── README.md
```

## Step 1: Define Model Arguments

Inherit from `BaseModelArgs`:

```python
# model/args.py
from dataclasses import dataclass

from torchtitan.protocols.model import BaseModelArgs


@dataclass
class YourModelArgs(BaseModelArgs):
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    vocab_size: int = 128256

    def get_nparams_and_flops(self, seq_len: int) -> tuple[int, int]:
        """Return (num_params, flops_per_token) for throughput calculation."""
        nparams = self.vocab_size * self.dim + ...  # Calculate params
        flops = 6 * nparams  # Approximate: 6 * params for forward+backward
        return nparams, flops

    def update_from_config(self, job_config) -> "YourModelArgs":
        """Update args from training config."""
        # Override specific args from job_config if needed
        return self
```
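
For a Llama-style decoder, the elided parameter count above can be filled in along these lines. This is a hedged sketch, not TorchTitan's actual accounting: `ToyArgs` and the per-layer coefficients are illustrative, and the exact term breakdown depends on your architecture (hidden-size ratio, tied embeddings, norms, biases).

```python
from dataclasses import dataclass


@dataclass
class ToyArgs:
    # Hypothetical stand-in for YourModelArgs, with the same defaults.
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    vocab_size: int = 128256


def nparams_and_flops(a: ToyArgs, seq_len: int) -> tuple[int, int]:
    # Embedding plus an untied output head.
    embed = 2 * a.vocab_size * a.dim
    # Per layer: attention (q, k, v, o projections, 4 * dim^2) plus an MLP
    # modeled here as 8 * dim^2; real coefficients are architecture-dependent.
    per_layer = 4 * a.dim * a.dim + 8 * a.dim * a.dim
    nparams = embed + a.n_layers * per_layer
    # Roughly 6 * N FLOPs per token for forward + backward, plus an attention
    # term that grows with sequence length.
    flops_per_token = 6 * nparams + 12 * a.n_layers * a.dim * seq_len
    return nparams, flops_per_token


n, f = nparams_and_flops(ToyArgs(), seq_len=8192)
print(f"{n / 1e9:.2f}B params, {f / 1e9:.1f} GFLOPs/token")
```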

## Step 2: Define Model

Inherit from `ModelProtocol`:

```python
# model/model.py
import torch
import torch.nn as nn

from torchtitan.protocols.model import ModelProtocol

from .args import YourModelArgs

# TransformerBlock and RMSNorm are your own modules, defined alongside the model.


class YourModel(ModelProtocol):
    def __init__(self, args: YourModelArgs):
        super().__init__()
        self.args = args
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleDict({
            str(i): TransformerBlock(args) for i in range(args.n_layers)
        })
        self.norm = RMSNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)
        for layer in self.layers.values():
            h = layer(h)
        h = self.norm(h)
        return self.output(h)

    def init_weights(self):
        """Initialize weights recursively."""
        for module in self.modules():
            if hasattr(module, 'init_weights') and module is not self:
                module.init_weights()
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
```

**Important guidelines**:

- Write single-device model code (parallelism is applied externally)
- Use `nn.ModuleDict` for layers (preserves FQNs when deleting layers for PP)
- Make input/output layers optional for PP compatibility
- Define `init_weights()` recursively
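
The `nn.ModuleDict` point can be seen directly: when pipeline parallelism deletes layers that belong to other stages, string-keyed entries keep the remaining fully qualified names (and hence checkpoint keys) stable, while `nn.ModuleList` reindexes them. A small demonstration:

```python
import torch.nn as nn

# ModuleDict keyed by string: deleting entries keeps remaining FQNs stable.
md = nn.ModuleDict({str(i): nn.Linear(4, 4) for i in range(4)})
del md["0"]
del md["1"]
print(sorted(md.state_dict().keys()))
# ['2.bias', '2.weight', '3.bias', '3.weight'] -- names still match the full model

# ModuleList reindexes on deletion, so checkpoint keys shift.
ml = nn.ModuleList([nn.Linear(4, 4) for _ in range(4)])
del ml[0]
del ml[1]
print(sorted(ml.state_dict().keys()))
# ['0.bias', '0.weight', '1.bias', '1.weight'] -- surviving layers were renamed
```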

## Step 3: Parallelize Function

```python
# infra/parallelize.py
import torch

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import parallelize_module

# apply_tp, apply_ac, and apply_fsdp are your own helpers, built on
# parallelize_module, checkpoint wrappers, and fully_shard respectively.


def parallelize_your_model(
    model: YourModel,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    # Apply in this order: TP -> AC -> compile -> FSDP

    # 1. Tensor Parallelism
    if parallel_dims.tp_enabled:
        apply_tp(model, world_mesh["tp"], job_config)

    # 2. Activation Checkpointing
    if job_config.activation_checkpoint.mode == "full":
        apply_ac(model, job_config)

    # 3. torch.compile
    if job_config.compile.enable:
        model = torch.compile(model)

    # 4. FSDP
    if parallel_dims.dp_enabled:
        apply_fsdp(model, world_mesh["dp"], job_config)

    return model
```
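
One way to read the ordering: each step wraps the result of the previous one, so FSDP, applied last, ends up outermost. A toy sketch of that composition, modeling wrapping as tuple nesting (names illustrative, not the real wrappers):

```python
def apply(model, transforms):
    # Each "wrapper" nests the model one level deeper, like the real
    # TP -> AC -> compile -> FSDP pipeline wraps modules in order.
    for name in transforms:
        model = (name, model)
    return model


wrapped = apply("model", ["tp", "ac", "compile", "fsdp"])
print(wrapped)
# ('fsdp', ('compile', ('ac', ('tp', 'model'))))
```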

## Step 4: Create TrainSpec

```python
# __init__.py
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec

from .infra.parallelize import parallelize_your_model
from .model.args import YourModelArgs
from .model.model import YourModel

MODEL_CONFIGS = {
    "8B": YourModelArgs(dim=4096, n_layers=32, n_heads=32),
    "70B": YourModelArgs(dim=8192, n_layers=80, n_heads=64),
}


def get_train_spec(flavor: str) -> TrainSpec:
    return TrainSpec(
        model_cls=YourModel,
        model_args=MODEL_CONFIGS[flavor],
        parallelize_fn=parallelize_your_model,
        pipeline_fn=None,  # Or your_pipeline_fn for PP
        build_optimizer_fn=build_optimizer,  # Reuse existing
        build_lr_scheduler_fn=build_lr_scheduler,  # Reuse existing
        build_dataloader_fn=build_dataloader,  # Reuse existing
        build_tokenizer_fn=build_tokenizer,  # Reuse existing
        build_loss_fn=build_loss,  # Reuse existing
        state_dict_adapter=None,  # Or YourStateDictAdapter
    )


# Register so train.py can find it
register_train_spec("your_model", get_train_spec)
```
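
Under the hood, registration of this kind is just a name-to-factory mapping. A minimal sketch of the pattern (hypothetical names, not torchtitan's actual internals): the trainer resolves the model name from the config to a registered factory, then builds the spec for the requested flavor.

```python
from typing import Callable

_TRAIN_SPECS: dict[str, Callable] = {}


def register_train_spec(name: str, factory: Callable) -> None:
    # Reject duplicate names so two models can't silently collide.
    if name in _TRAIN_SPECS:
        raise ValueError(f"train spec {name!r} already registered")
    _TRAIN_SPECS[name] = factory


def get_registered_spec(name: str, flavor: str):
    # The trainer looks up [model].name, then builds the requested flavor.
    return _TRAIN_SPECS[name](flavor)


register_train_spec("your_model", lambda flavor: {"flavor": flavor})
print(get_registered_spec("your_model", "8B"))  # {'flavor': '8B'}
```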

## Step 5: State Dict Adapter (Optional)

For HuggingFace checkpoint conversion:

```python
# model/state_dict_adapter.py
from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter


class YourStateDictAdapter(BaseStateDictAdapter):
    def to_hf(self, state_dict: dict) -> dict:
        """Convert torchtitan state dict to HF format."""
        hf_state_dict = {}
        for key, value in state_dict.items():
            hf_key = self._convert_key_to_hf(key)
            hf_state_dict[hf_key] = value
        return hf_state_dict

    def from_hf(self, state_dict: dict) -> dict:
        """Convert HF state dict to torchtitan format."""
        tt_state_dict = {}
        for key, value in state_dict.items():
            tt_key = self._convert_key_from_hf(key)
            tt_state_dict[tt_key] = value
        return tt_state_dict
```
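
The `_convert_key_*` helpers are where the real work lives: for a Llama-style model the mapping is typically a handful of prefix and name rewrites. A hedged sketch (the rule table is hypothetical; the exact HF names depend on the target architecture, and a real adapter needs rules for every parameter):

```python
import re

# Hypothetical torchtitan -> HF name rewrites for a Llama-style model.
_TO_HF_RULES = [
    (re.compile(r"^tok_embeddings\.weight$"), "model.embed_tokens.weight"),
    (re.compile(r"^layers\.(\d+)\.attention\.wq\.weight$"),
     r"model.layers.\1.self_attn.q_proj.weight"),
    (re.compile(r"^norm\.weight$"), "model.norm.weight"),
    (re.compile(r"^output\.weight$"), "lm_head.weight"),
]


def convert_key_to_hf(key: str) -> str:
    for pattern, repl in _TO_HF_RULES:
        if pattern.match(key):
            return pattern.sub(repl, key)
    raise KeyError(f"no HF mapping for {key!r}")


print(convert_key_to_hf("layers.5.attention.wq.weight"))
# model.layers.5.self_attn.q_proj.weight
```

The reverse direction (`from_hf`) is the same table applied with the patterns and replacements swapped.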

## Step 6: Training Config

```toml
# train_configs/your_model_8b.toml
[job]
dump_folder = "./outputs"
description = "Your Model 8B training"

[model]
name = "your_model"
flavor = "8B"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
local_batch_size = 2
seq_len = 8192
steps = 1000
dataset = "c4"

[parallelism]
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
```
|
|
214
|
+
|
|
215
|
+
## Step 7: Register Model
|
|
216
|
+
|
|
217
|
+
Add to `torchtitan/models/__init__.py`:
|
|
218
|
+
|
|
219
|
+
```python
|
|
220
|
+
from .your_model import get_train_spec as get_your_model_train_spec
|
|
221
|
+
|
|
222
|
+
MODEL_REGISTRY["your_model"] = get_your_model_train_spec
|
|
223
|
+
```
|
|
224
|
+
|
|
225
|
+
## Testing
|
|
226
|
+
|
|
227
|
+
### Numerics Test
|
|
228
|
+
|
|
229
|
+
Compare output with HuggingFace implementation:
|
|
230
|
+
|
|
231
|
+
```python
|
|
232
|
+
def test_numerics():
|
|
233
|
+
# Load same checkpoint into both implementations
|
|
234
|
+
tt_model = YourModel(args).load_checkpoint(...)
|
|
235
|
+
hf_model = HFYourModel.from_pretrained(...)
|
|
236
|
+
|
|
237
|
+
# Compare outputs
|
|
238
|
+
input_ids = torch.randint(0, vocab_size, (1, 128))
|
|
239
|
+
tt_output = tt_model(input_ids)
|
|
240
|
+
hf_output = hf_model(input_ids).logits
|
|
241
|
+
|
|
242
|
+
torch.testing.assert_close(tt_output, hf_output, atol=1e-4, rtol=1e-4)
|
|
243
|
+
```
|
|
244
|
+
|
|
245
|
+
### Loss Convergence
|
|
246
|
+
|
|
247
|
+
Compare loss curves with verified baseline (see `docs/converging.md`).
|
|
248
|
+
|
|
249
|
+
### Performance Benchmark
|
|
250
|
+
|
|
251
|
+
Add benchmark config to `benchmarks/` folder.
|
|
252
|
+
|
|
253
|
+
## Guiding Principles
|
|
254
|
+
|
|
255
|
+
1. **Readability over flexibility**: Don't over-abstract
|
|
256
|
+
2. **Minimal model changes**: Parallelism applied externally
|
|
257
|
+
3. **Clean, minimal codebase**: Reuse existing components where possible
|
|
258
|
+
4. **Single-device semantics**: Model code should work on single GPU
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Float8 Training in TorchTitan

Float8 training provides substantial speedups for models where GEMMs are large enough that the FP8 tensor-core speedup outweighs the dynamic quantization overhead.

## Hardware Requirements

- NVIDIA H100 or newer GPUs (FP8 Tensor Cores)
- Blackwell GPUs for MXFP8 training

## Installation

```bash
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
```

## Usage: Tensorwise Scaling

Standard Float8 with tensorwise dynamic scaling:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
  --model.converters="quantize.linear.float8" \
  --quantize.linear.float8.enable_fsdp_float8_all_gather \
  --quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp \
  --compile.enable
```

### Key Arguments

| Argument | Description |
|----------|-------------|
| `--model.converters="quantize.linear.float8"` | Swap `nn.Linear` with `Float8Linear` |
| `--quantize.linear.float8.enable_fsdp_float8_all_gather` | Communicate in float8 to save bandwidth |
| `--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp` | Single all-reduce for all AMAX/scales |
| `--compile.enable` | Required - fuses float8 scaling/casting kernels |

## Usage: Rowwise Scaling

Higher accuracy than tensorwise scaling:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
  --model.converters="quantize.linear.float8" \
  --quantize.linear.float8.recipe_name rowwise \
  --compile.enable
```

## Filtering Layers

Not all layers benefit from Float8. Filter small layers:

```bash
--quantize.linear.float8.filter_fqns="attention.wk,attention.wv,output"
```

### Auto-filtering

Automatically skip layers too small to benefit:

```bash
--quantize.linear.float8.filter_fqns="auto_filter_small_kn"
```

Thresholds are based on H100 microbenchmarks where the speedup exceeds the overhead.

## TOML Configuration

```toml
[model]
converters = ["quantize.linear.float8"]

[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output", "auto_filter_small_kn"]

[compile]
enable = true
components = ["model", "loss"]
```

## How Float8 Works with Distributed Training

### Single Device

Cast input and weight to float8 inside forward before calling `torch._scaled_mm`:

```python
# Float8 matmul requires scales
torch._scaled_mm(input_fp8, weight_fp8, scale_a=scale_input, scale_b=scale_weight)
```
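The scales passed to `torch._scaled_mm` are computed dynamically from each tensor's absolute maximum. A dependency-free sketch of the tensorwise scaling math (plain Python standing in for the fused kernels; rounding and saturation to the e4m3 grid are elided):

```python
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3

def tensorwise_scale(tensor):
    """Dynamic scale: map the tensor's amax onto the float8 range."""
    amax = max(abs(x) for x in tensor)
    return E4M3_MAX / amax

def quantize(tensor, scale):
    """Scale values into the float8 range (e4m3 rounding elided)."""
    return [x * scale for x in tensor]

def dequantize(tensor, scale):
    return [x / scale for x in tensor]

w = [0.5, -2.0, 1.25, 4.0]
s = tensorwise_scale(w)       # 448.0 / 4.0 = 112.0
w_fp8 = quantize(w, s)        # values now span the float8 range
print(dequantize(w_fp8, s))   # [0.5, -2.0, 1.25, 4.0]
```

In the real kernels both quantization and the scale computation are fused by `torch.compile`, which is why `--compile.enable` is required.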
### FSDP + Float8

1. Cast sharded high-precision weights (1/N per rank) to float8
2. Perform the float8 all-gather (saves bandwidth vs bf16/fp32)
3. Communicate `max(abs)` across ranks for scale computation
4. At forward start, have the unsharded float8 weights ready

**Net benefit**: Float8 all-gather + amax communication can beat a bf16/fp32 all-gather, depending on world size and message size.
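The bandwidth claim in step 2 can be checked with back-of-the-envelope arithmetic (a sketch; the real traffic also includes the small per-tensor amax/scale exchange):

```python
def all_gather_bytes(num_params, bytes_per_elem, world_size):
    """Bytes each rank receives when gathering the other ranks' shards."""
    shard = num_params // world_size
    return shard * (world_size - 1) * bytes_per_elem

params = 8_000_000_000  # illustrative 8B-parameter model
bf16 = all_gather_bytes(params, 2, world_size=8)  # 2 bytes/element
fp8 = all_gather_bytes(params, 1, world_size=8)   # 1 byte/element

print(bf16 / fp8)  # 2.0: float8 halves the all-gather traffic vs bf16
```

The halved payload has to outweigh the extra amax all-reduce, which is why the win depends on world size and message size.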
### TP + Float8

- **Input**: Cast sharded input to float8, all-gather in float8
- **Weights**: Communicate `max(abs)` for sharded weights
- **Matmul**: Float8 input (unsharded) x float8 weight (sharded) with global scales

## Scaling Strategies

| Strategy | Status | Description |
|----------|--------|-------------|
| Tensorwise dynamic | Stable | Single scale per tensor |
| Rowwise dynamic | Alpha | Scale per row, higher accuracy |
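Why rowwise scaling is more accurate: with a single scale per tensor, one large-magnitude row forces every small-magnitude row into a tiny slice of the float8 range. A toy illustration (values chosen to divide exactly):

```python
E4M3_MAX = 448.0  # largest finite value in float8 e4m3

def rowwise_scales(matrix):
    """One scale per row: each row's own amax maps to the float8 max."""
    return [E4M3_MAX / max(abs(x) for x in row) for row in matrix]

w = [[0.015625, -0.03125],   # small-magnitude row
     [128.0, -64.0]]         # large-magnitude row

# Tensorwise: a single scale dominated by the global amax (128.0),
# so the small row barely uses the float8 range.
tensorwise = E4M3_MAX / 128.0     # 3.5
# Rowwise: the small row gets a far larger scale, preserving resolution.
print(rowwise_scales(w))          # [14336.0, 3.5]
```

The cost is storing and applying a vector of scales instead of one scalar, which is part of why the rowwise recipe is still marked alpha.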
## Performance Gains

From benchmarks on H100:

| Configuration | TPS/GPU | vs Baseline |
|---------------|---------|-------------|
| FSDP only | 5,762 | - |
| FSDP + compile | 6,667 | +16% |
| FSDP + compile + Float8 | 8,532 | +48% |

## Determining Float8 Benefit

Check the [torchao microbenchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) for forward+backward pass speedups on "layer norm => linear => sigmoid" for different M,N,K sizes.

Rule of thumb: GEMMs with K,N > 4096 typically benefit from Float8.
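A dependency-free sketch of what a filter like `auto_filter_small_kn` might check, using the rule of thumb above as the threshold (the shipped thresholds come from the H100 microbenchmarks, not from this constant):

```python
MIN_DIM = 4096  # rule-of-thumb cutoff; real thresholds are benchmark-derived

def should_convert_to_float8(in_features, out_features):
    """Skip linears whose GEMM dims (K, N) are too small to amortize
    the dynamic quantization overhead."""
    return in_features > MIN_DIM and out_features > MIN_DIM

print(should_convert_to_float8(8192, 8192))   # True: large GEMM, worth it
print(should_convert_to_float8(4096, 1024))   # False: too small to benefit
```

This is why small projections like `attention.wk`/`attention.wv` often appear in `filter_fqns` even when the rest of the model is converted.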
## MXFP8 Training (Blackwell)

For NVIDIA Blackwell GPUs, TorchTitan supports MXFP8 (Microscaling FP8) for both dense and MoE models. See [docs/mxfp8.md](https://github.com/pytorch/torchtitan/blob/main/docs/mxfp8.md) for details.

@@ -0,0 +1,126 @@

# FSDP2 in TorchTitan

## Why FSDP2?

FSDP2 is a rewrite of PyTorch's Fully Sharded Data Parallel (FSDP) API, removing the `FlatParameter` abstraction for better composability and a simpler implementation.

### Key improvements over FSDP1

- **DTensor-based sharding**: Sharded parameters are `DTensor`s on dim-0, enabling easy manipulation and communication-free sharded state dicts
- **Better memory management**: Deterministic and lower GPU memory (7% reduction) by avoiding `recordStream`
- **Simplified API**: Fewer arguments, no wrapper class

### Performance

On Llama-7B with 8x H100s, FSDP2 achieves higher MFU with 7% lower peak memory than FSDP1, matching the same loss curve.

## API Reference

```python
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, OffloadPolicy

@contract(state_cls=FSDPState)
def fully_shard(
    module: nn.Module,
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Union[bool, int] = True,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
) -> nn.Module:
    ...
```

## Sharding Strategies (ZeRO Equivalents)

| FSDP2 Configuration | FSDP1 Equivalent | DeepSpeed |
|---------------------|------------------|-----------|
| 1D mesh + `reshard_after_forward=True` | FULL_SHARD | ZeRO-3 |
| 1D mesh + `reshard_after_forward=False` | SHARD_GRAD_OP | ZeRO-2 |
| 2D mesh + `reshard_after_forward=True` | HYBRID_SHARD | MiCS |
| 1D/2D mesh + `reshard_after_forward=8` (int) | - | ZeRO++ hpZ |
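The memory effect of full sharding (ZeRO-3) can be sketched with simple arithmetic, assuming bf16 params and grads plus two fp32 Adam states per parameter (illustrative; activations and temporaries are ignored):

```python
def per_rank_gb(num_params, world_size):
    """Approximate per-rank bytes for params + grads + optimizer states
    under full sharding: every component is split 1/N across ranks."""
    bytes_per_param = 2 + 2 + 8  # bf16 param + bf16 grad + two fp32 Adam states
    return num_params * bytes_per_param / world_size / 1e9

params = 7_000_000_000  # e.g. Llama-7B
print(per_rank_gb(params, world_size=1))  # 84.0 GB: far beyond one GPU
print(per_rank_gb(params, world_size=8))  # 10.5 GB per rank when sharded 8 ways
```

`reshard_after_forward=False` (ZeRO-2 equivalent) keeps gathered params resident between forward and backward, trading some of this saving for fewer all-gathers.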
## Meta-Device Initialization

FSDP2 supports materializing tensors onto GPU _after_ sharding:

```python
# Initialize on meta device (no memory)
with torch.device("meta"):
    model = Transformer()

# Apply FSDP2 sharding
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)

# Parameters still on meta device
for tensor in itertools.chain(model.parameters(), model.buffers()):
    assert tensor.device == torch.device("meta")

# Allocate sharded parameters on GPU
model.to_empty(device="cuda")

# Initialize weights
model.init_weights()
```

## State Dict Differences

| Operation | FSDP1 | FSDP2 |
|-----------|-------|-------|
| `model.state_dict()` | Full state dict | Sharded state dict (no communication) |
| `optim.state_dict()` | Local state dict | Sharded state dict (no communication) |
| `summon_full_params()` | Supported | Use `DTensor` APIs like `full_tensor()` |
| Gradient clipping | `FSDP.clip_grad_norm_()` | `nn.utils.clip_grad_norm_()` |

## Mixed Precision

```python
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    output_dtype=torch.bfloat16,
    cast_forward_inputs=True,
)

fully_shard(model, mp_policy=mp_policy)
```

## HSDP (Hybrid Sharded Data Parallel)

For 2D parallelism with replication + sharding:

```python
from torch.distributed.device_mesh import init_device_mesh

# Replicate across 4 groups, shard within 8 GPUs each
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))

fully_shard(model, mesh=mesh)
```

## Configuration in TorchTitan

```toml
[parallelism]
# FSDP sharding degree (-1 = auto, use all available GPUs)
data_parallel_shard_degree = -1

# HSDP replication degree (1 = pure FSDP, >1 = HSDP)
data_parallel_replicate_degree = 1
```

## Removed Arguments from FSDP1

These FSDP1 arguments are no longer needed:

- `auto_wrap_policy`: Apply `fully_shard` directly to modules
- `backward_prefetch`: Always uses BACKWARD_PRE
- `param_init_fn`: Use meta-device initialization
- `device_id`: Uses the mesh's device automatically
- `sync_module_states`: Not needed with DTensor
- `limit_all_gathers`: New memory management doesn't need it
- `use_orig_params`: Always true (no `FlatParameter`)