@synsci/cli-darwin-x64 1.1.49
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/skills/accelerate/SKILL.md +332 -0
- package/bin/skills/accelerate/references/custom-plugins.md +453 -0
- package/bin/skills/accelerate/references/megatron-integration.md +489 -0
- package/bin/skills/accelerate/references/performance.md +525 -0
- package/bin/skills/audiocraft/SKILL.md +564 -0
- package/bin/skills/audiocraft/references/advanced-usage.md +666 -0
- package/bin/skills/audiocraft/references/troubleshooting.md +504 -0
- package/bin/skills/autogpt/SKILL.md +403 -0
- package/bin/skills/autogpt/references/advanced-usage.md +535 -0
- package/bin/skills/autogpt/references/troubleshooting.md +420 -0
- package/bin/skills/awq/SKILL.md +310 -0
- package/bin/skills/awq/references/advanced-usage.md +324 -0
- package/bin/skills/awq/references/troubleshooting.md +344 -0
- package/bin/skills/axolotl/SKILL.md +158 -0
- package/bin/skills/axolotl/references/api.md +5548 -0
- package/bin/skills/axolotl/references/dataset-formats.md +1029 -0
- package/bin/skills/axolotl/references/index.md +15 -0
- package/bin/skills/axolotl/references/other.md +3563 -0
- package/bin/skills/bigcode-evaluation-harness/SKILL.md +405 -0
- package/bin/skills/bigcode-evaluation-harness/references/benchmarks.md +393 -0
- package/bin/skills/bigcode-evaluation-harness/references/custom-tasks.md +424 -0
- package/bin/skills/bigcode-evaluation-harness/references/issues.md +394 -0
- package/bin/skills/bitsandbytes/SKILL.md +411 -0
- package/bin/skills/bitsandbytes/references/memory-optimization.md +521 -0
- package/bin/skills/bitsandbytes/references/qlora-training.md +521 -0
- package/bin/skills/bitsandbytes/references/quantization-formats.md +447 -0
- package/bin/skills/blip-2/SKILL.md +564 -0
- package/bin/skills/blip-2/references/advanced-usage.md +680 -0
- package/bin/skills/blip-2/references/troubleshooting.md +526 -0
- package/bin/skills/chroma/SKILL.md +406 -0
- package/bin/skills/chroma/references/integration.md +38 -0
- package/bin/skills/clip/SKILL.md +253 -0
- package/bin/skills/clip/references/applications.md +207 -0
- package/bin/skills/constitutional-ai/SKILL.md +290 -0
- package/bin/skills/crewai/SKILL.md +498 -0
- package/bin/skills/crewai/references/flows.md +438 -0
- package/bin/skills/crewai/references/tools.md +429 -0
- package/bin/skills/crewai/references/troubleshooting.md +480 -0
- package/bin/skills/deepspeed/SKILL.md +141 -0
- package/bin/skills/deepspeed/references/08.md +17 -0
- package/bin/skills/deepspeed/references/09.md +173 -0
- package/bin/skills/deepspeed/references/2020.md +378 -0
- package/bin/skills/deepspeed/references/2023.md +279 -0
- package/bin/skills/deepspeed/references/assets.md +179 -0
- package/bin/skills/deepspeed/references/index.md +35 -0
- package/bin/skills/deepspeed/references/mii.md +118 -0
- package/bin/skills/deepspeed/references/other.md +1191 -0
- package/bin/skills/deepspeed/references/tutorials.md +6554 -0
- package/bin/skills/dspy/SKILL.md +590 -0
- package/bin/skills/dspy/references/examples.md +663 -0
- package/bin/skills/dspy/references/modules.md +475 -0
- package/bin/skills/dspy/references/optimizers.md +566 -0
- package/bin/skills/faiss/SKILL.md +221 -0
- package/bin/skills/faiss/references/index_types.md +280 -0
- package/bin/skills/flash-attention/SKILL.md +367 -0
- package/bin/skills/flash-attention/references/benchmarks.md +215 -0
- package/bin/skills/flash-attention/references/transformers-integration.md +293 -0
- package/bin/skills/gguf/SKILL.md +427 -0
- package/bin/skills/gguf/references/advanced-usage.md +504 -0
- package/bin/skills/gguf/references/troubleshooting.md +442 -0
- package/bin/skills/gptq/SKILL.md +450 -0
- package/bin/skills/gptq/references/calibration.md +337 -0
- package/bin/skills/gptq/references/integration.md +129 -0
- package/bin/skills/gptq/references/troubleshooting.md +95 -0
- package/bin/skills/grpo-rl-training/README.md +97 -0
- package/bin/skills/grpo-rl-training/SKILL.md +572 -0
- package/bin/skills/grpo-rl-training/examples/reward_functions_library.py +393 -0
- package/bin/skills/grpo-rl-training/templates/basic_grpo_training.py +228 -0
- package/bin/skills/guidance/SKILL.md +572 -0
- package/bin/skills/guidance/references/backends.md +554 -0
- package/bin/skills/guidance/references/constraints.md +674 -0
- package/bin/skills/guidance/references/examples.md +767 -0
- package/bin/skills/hqq/SKILL.md +445 -0
- package/bin/skills/hqq/references/advanced-usage.md +528 -0
- package/bin/skills/hqq/references/troubleshooting.md +503 -0
- package/bin/skills/hugging-face-cli/SKILL.md +191 -0
- package/bin/skills/hugging-face-cli/references/commands.md +954 -0
- package/bin/skills/hugging-face-cli/references/examples.md +374 -0
- package/bin/skills/hugging-face-datasets/SKILL.md +547 -0
- package/bin/skills/hugging-face-datasets/examples/diverse_training_examples.json +239 -0
- package/bin/skills/hugging-face-datasets/examples/system_prompt_template.txt +196 -0
- package/bin/skills/hugging-face-datasets/examples/training_examples.json +176 -0
- package/bin/skills/hugging-face-datasets/scripts/dataset_manager.py +522 -0
- package/bin/skills/hugging-face-datasets/scripts/sql_manager.py +844 -0
- package/bin/skills/hugging-face-datasets/templates/chat.json +55 -0
- package/bin/skills/hugging-face-datasets/templates/classification.json +62 -0
- package/bin/skills/hugging-face-datasets/templates/completion.json +51 -0
- package/bin/skills/hugging-face-datasets/templates/custom.json +75 -0
- package/bin/skills/hugging-face-datasets/templates/qa.json +54 -0
- package/bin/skills/hugging-face-datasets/templates/tabular.json +81 -0
- package/bin/skills/hugging-face-evaluation/SKILL.md +656 -0
- package/bin/skills/hugging-face-evaluation/examples/USAGE_EXAMPLES.md +382 -0
- package/bin/skills/hugging-face-evaluation/examples/artificial_analysis_to_hub.py +141 -0
- package/bin/skills/hugging-face-evaluation/examples/example_readme_tables.md +135 -0
- package/bin/skills/hugging-face-evaluation/examples/metric_mapping.json +50 -0
- package/bin/skills/hugging-face-evaluation/requirements.txt +20 -0
- package/bin/skills/hugging-face-evaluation/scripts/evaluation_manager.py +1374 -0
- package/bin/skills/hugging-face-evaluation/scripts/inspect_eval_uv.py +104 -0
- package/bin/skills/hugging-face-evaluation/scripts/inspect_vllm_uv.py +317 -0
- package/bin/skills/hugging-face-evaluation/scripts/lighteval_vllm_uv.py +303 -0
- package/bin/skills/hugging-face-evaluation/scripts/run_eval_job.py +98 -0
- package/bin/skills/hugging-face-evaluation/scripts/run_vllm_eval_job.py +331 -0
- package/bin/skills/hugging-face-evaluation/scripts/test_extraction.py +206 -0
- package/bin/skills/hugging-face-jobs/SKILL.md +1041 -0
- package/bin/skills/hugging-face-jobs/index.html +216 -0
- package/bin/skills/hugging-face-jobs/references/hardware_guide.md +336 -0
- package/bin/skills/hugging-face-jobs/references/hub_saving.md +352 -0
- package/bin/skills/hugging-face-jobs/references/token_usage.md +546 -0
- package/bin/skills/hugging-face-jobs/references/troubleshooting.md +475 -0
- package/bin/skills/hugging-face-jobs/scripts/cot-self-instruct.py +718 -0
- package/bin/skills/hugging-face-jobs/scripts/finepdfs-stats.py +546 -0
- package/bin/skills/hugging-face-jobs/scripts/generate-responses.py +587 -0
- package/bin/skills/hugging-face-model-trainer/SKILL.md +711 -0
- package/bin/skills/hugging-face-model-trainer/references/gguf_conversion.md +296 -0
- package/bin/skills/hugging-face-model-trainer/references/hardware_guide.md +283 -0
- package/bin/skills/hugging-face-model-trainer/references/hub_saving.md +364 -0
- package/bin/skills/hugging-face-model-trainer/references/reliability_principles.md +371 -0
- package/bin/skills/hugging-face-model-trainer/references/trackio_guide.md +189 -0
- package/bin/skills/hugging-face-model-trainer/references/training_methods.md +150 -0
- package/bin/skills/hugging-face-model-trainer/references/training_patterns.md +203 -0
- package/bin/skills/hugging-face-model-trainer/references/troubleshooting.md +282 -0
- package/bin/skills/hugging-face-model-trainer/scripts/convert_to_gguf.py +424 -0
- package/bin/skills/hugging-face-model-trainer/scripts/dataset_inspector.py +417 -0
- package/bin/skills/hugging-face-model-trainer/scripts/estimate_cost.py +150 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_dpo_example.py +106 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_grpo_example.py +89 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_sft_example.py +122 -0
- package/bin/skills/hugging-face-paper-publisher/SKILL.md +627 -0
- package/bin/skills/hugging-face-paper-publisher/examples/example_usage.md +327 -0
- package/bin/skills/hugging-face-paper-publisher/references/quick_reference.md +216 -0
- package/bin/skills/hugging-face-paper-publisher/scripts/paper_manager.py +508 -0
- package/bin/skills/hugging-face-paper-publisher/templates/arxiv.md +299 -0
- package/bin/skills/hugging-face-paper-publisher/templates/ml-report.md +358 -0
- package/bin/skills/hugging-face-paper-publisher/templates/modern.md +319 -0
- package/bin/skills/hugging-face-paper-publisher/templates/standard.md +201 -0
- package/bin/skills/hugging-face-tool-builder/SKILL.md +115 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.py +57 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.sh +40 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.tsx +57 -0
- package/bin/skills/hugging-face-tool-builder/references/find_models_by_paper.sh +230 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_enrich_models.sh +96 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_model_card_frontmatter.sh +188 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_model_papers_auth.sh +171 -0
- package/bin/skills/hugging-face-trackio/SKILL.md +65 -0
- package/bin/skills/hugging-face-trackio/references/logging_metrics.md +206 -0
- package/bin/skills/hugging-face-trackio/references/retrieving_metrics.md +223 -0
- package/bin/skills/huggingface-tokenizers/SKILL.md +516 -0
- package/bin/skills/huggingface-tokenizers/references/algorithms.md +653 -0
- package/bin/skills/huggingface-tokenizers/references/integration.md +637 -0
- package/bin/skills/huggingface-tokenizers/references/pipeline.md +723 -0
- package/bin/skills/huggingface-tokenizers/references/training.md +565 -0
- package/bin/skills/instructor/SKILL.md +740 -0
- package/bin/skills/instructor/references/examples.md +107 -0
- package/bin/skills/instructor/references/providers.md +70 -0
- package/bin/skills/instructor/references/validation.md +606 -0
- package/bin/skills/knowledge-distillation/SKILL.md +458 -0
- package/bin/skills/knowledge-distillation/references/minillm.md +334 -0
- package/bin/skills/lambda-labs/SKILL.md +545 -0
- package/bin/skills/lambda-labs/references/advanced-usage.md +611 -0
- package/bin/skills/lambda-labs/references/troubleshooting.md +530 -0
- package/bin/skills/langchain/SKILL.md +480 -0
- package/bin/skills/langchain/references/agents.md +499 -0
- package/bin/skills/langchain/references/integration.md +562 -0
- package/bin/skills/langchain/references/rag.md +600 -0
- package/bin/skills/langsmith/SKILL.md +422 -0
- package/bin/skills/langsmith/references/advanced-usage.md +548 -0
- package/bin/skills/langsmith/references/troubleshooting.md +537 -0
- package/bin/skills/litgpt/SKILL.md +469 -0
- package/bin/skills/litgpt/references/custom-models.md +568 -0
- package/bin/skills/litgpt/references/distributed-training.md +451 -0
- package/bin/skills/litgpt/references/supported-models.md +336 -0
- package/bin/skills/litgpt/references/training-recipes.md +619 -0
- package/bin/skills/llama-cpp/SKILL.md +258 -0
- package/bin/skills/llama-cpp/references/optimization.md +89 -0
- package/bin/skills/llama-cpp/references/quantization.md +213 -0
- package/bin/skills/llama-cpp/references/server.md +125 -0
- package/bin/skills/llama-factory/SKILL.md +80 -0
- package/bin/skills/llama-factory/references/_images.md +23 -0
- package/bin/skills/llama-factory/references/advanced.md +1055 -0
- package/bin/skills/llama-factory/references/getting_started.md +349 -0
- package/bin/skills/llama-factory/references/index.md +19 -0
- package/bin/skills/llama-factory/references/other.md +31 -0
- package/bin/skills/llamaguard/SKILL.md +337 -0
- package/bin/skills/llamaindex/SKILL.md +569 -0
- package/bin/skills/llamaindex/references/agents.md +83 -0
- package/bin/skills/llamaindex/references/data_connectors.md +108 -0
- package/bin/skills/llamaindex/references/query_engines.md +406 -0
- package/bin/skills/llava/SKILL.md +304 -0
- package/bin/skills/llava/references/training.md +197 -0
- package/bin/skills/lm-evaluation-harness/SKILL.md +490 -0
- package/bin/skills/lm-evaluation-harness/references/api-evaluation.md +490 -0
- package/bin/skills/lm-evaluation-harness/references/benchmark-guide.md +488 -0
- package/bin/skills/lm-evaluation-harness/references/custom-tasks.md +602 -0
- package/bin/skills/lm-evaluation-harness/references/distributed-eval.md +519 -0
- package/bin/skills/long-context/SKILL.md +536 -0
- package/bin/skills/long-context/references/extension_methods.md +468 -0
- package/bin/skills/long-context/references/fine_tuning.md +611 -0
- package/bin/skills/long-context/references/rope.md +402 -0
- package/bin/skills/mamba/SKILL.md +260 -0
- package/bin/skills/mamba/references/architecture-details.md +206 -0
- package/bin/skills/mamba/references/benchmarks.md +255 -0
- package/bin/skills/mamba/references/training-guide.md +388 -0
- package/bin/skills/megatron-core/SKILL.md +366 -0
- package/bin/skills/megatron-core/references/benchmarks.md +249 -0
- package/bin/skills/megatron-core/references/parallelism-guide.md +404 -0
- package/bin/skills/megatron-core/references/production-examples.md +473 -0
- package/bin/skills/megatron-core/references/training-recipes.md +547 -0
- package/bin/skills/miles/SKILL.md +315 -0
- package/bin/skills/miles/references/api-reference.md +141 -0
- package/bin/skills/miles/references/troubleshooting.md +352 -0
- package/bin/skills/mlflow/SKILL.md +704 -0
- package/bin/skills/mlflow/references/deployment.md +744 -0
- package/bin/skills/mlflow/references/model-registry.md +770 -0
- package/bin/skills/mlflow/references/tracking.md +680 -0
- package/bin/skills/modal/SKILL.md +341 -0
- package/bin/skills/modal/references/advanced-usage.md +503 -0
- package/bin/skills/modal/references/troubleshooting.md +494 -0
- package/bin/skills/model-merging/SKILL.md +539 -0
- package/bin/skills/model-merging/references/evaluation.md +462 -0
- package/bin/skills/model-merging/references/examples.md +428 -0
- package/bin/skills/model-merging/references/methods.md +352 -0
- package/bin/skills/model-pruning/SKILL.md +495 -0
- package/bin/skills/model-pruning/references/wanda.md +347 -0
- package/bin/skills/moe-training/SKILL.md +526 -0
- package/bin/skills/moe-training/references/architectures.md +432 -0
- package/bin/skills/moe-training/references/inference.md +348 -0
- package/bin/skills/moe-training/references/training.md +425 -0
- package/bin/skills/nanogpt/SKILL.md +290 -0
- package/bin/skills/nanogpt/references/architecture.md +382 -0
- package/bin/skills/nanogpt/references/data.md +476 -0
- package/bin/skills/nanogpt/references/training.md +564 -0
- package/bin/skills/nemo-curator/SKILL.md +383 -0
- package/bin/skills/nemo-curator/references/deduplication.md +87 -0
- package/bin/skills/nemo-curator/references/filtering.md +102 -0
- package/bin/skills/nemo-evaluator/SKILL.md +494 -0
- package/bin/skills/nemo-evaluator/references/adapter-system.md +340 -0
- package/bin/skills/nemo-evaluator/references/configuration.md +447 -0
- package/bin/skills/nemo-evaluator/references/custom-benchmarks.md +315 -0
- package/bin/skills/nemo-evaluator/references/execution-backends.md +361 -0
- package/bin/skills/nemo-guardrails/SKILL.md +297 -0
- package/bin/skills/nnsight/SKILL.md +436 -0
- package/bin/skills/nnsight/references/README.md +78 -0
- package/bin/skills/nnsight/references/api.md +344 -0
- package/bin/skills/nnsight/references/tutorials.md +300 -0
- package/bin/skills/openrlhf/SKILL.md +249 -0
- package/bin/skills/openrlhf/references/algorithm-comparison.md +404 -0
- package/bin/skills/openrlhf/references/custom-rewards.md +530 -0
- package/bin/skills/openrlhf/references/hybrid-engine.md +287 -0
- package/bin/skills/openrlhf/references/multi-node-training.md +454 -0
- package/bin/skills/outlines/SKILL.md +652 -0
- package/bin/skills/outlines/references/backends.md +615 -0
- package/bin/skills/outlines/references/examples.md +773 -0
- package/bin/skills/outlines/references/json_generation.md +652 -0
- package/bin/skills/peft/SKILL.md +431 -0
- package/bin/skills/peft/references/advanced-usage.md +514 -0
- package/bin/skills/peft/references/troubleshooting.md +480 -0
- package/bin/skills/phoenix/SKILL.md +475 -0
- package/bin/skills/phoenix/references/advanced-usage.md +619 -0
- package/bin/skills/phoenix/references/troubleshooting.md +538 -0
- package/bin/skills/pinecone/SKILL.md +358 -0
- package/bin/skills/pinecone/references/deployment.md +181 -0
- package/bin/skills/pytorch-fsdp/SKILL.md +126 -0
- package/bin/skills/pytorch-fsdp/references/index.md +7 -0
- package/bin/skills/pytorch-fsdp/references/other.md +4249 -0
- package/bin/skills/pytorch-lightning/SKILL.md +346 -0
- package/bin/skills/pytorch-lightning/references/callbacks.md +436 -0
- package/bin/skills/pytorch-lightning/references/distributed.md +490 -0
- package/bin/skills/pytorch-lightning/references/hyperparameter-tuning.md +556 -0
- package/bin/skills/pyvene/SKILL.md +473 -0
- package/bin/skills/pyvene/references/README.md +73 -0
- package/bin/skills/pyvene/references/api.md +383 -0
- package/bin/skills/pyvene/references/tutorials.md +376 -0
- package/bin/skills/qdrant/SKILL.md +493 -0
- package/bin/skills/qdrant/references/advanced-usage.md +648 -0
- package/bin/skills/qdrant/references/troubleshooting.md +631 -0
- package/bin/skills/ray-data/SKILL.md +326 -0
- package/bin/skills/ray-data/references/integration.md +82 -0
- package/bin/skills/ray-data/references/transformations.md +83 -0
- package/bin/skills/ray-train/SKILL.md +406 -0
- package/bin/skills/ray-train/references/multi-node.md +628 -0
- package/bin/skills/rwkv/SKILL.md +260 -0
- package/bin/skills/rwkv/references/architecture-details.md +344 -0
- package/bin/skills/rwkv/references/rwkv7.md +386 -0
- package/bin/skills/rwkv/references/state-management.md +369 -0
- package/bin/skills/saelens/SKILL.md +386 -0
- package/bin/skills/saelens/references/README.md +70 -0
- package/bin/skills/saelens/references/api.md +333 -0
- package/bin/skills/saelens/references/tutorials.md +318 -0
- package/bin/skills/segment-anything/SKILL.md +500 -0
- package/bin/skills/segment-anything/references/advanced-usage.md +589 -0
- package/bin/skills/segment-anything/references/troubleshooting.md +484 -0
- package/bin/skills/sentence-transformers/SKILL.md +255 -0
- package/bin/skills/sentence-transformers/references/models.md +123 -0
- package/bin/skills/sentencepiece/SKILL.md +235 -0
- package/bin/skills/sentencepiece/references/algorithms.md +200 -0
- package/bin/skills/sentencepiece/references/training.md +304 -0
- package/bin/skills/sglang/SKILL.md +442 -0
- package/bin/skills/sglang/references/deployment.md +490 -0
- package/bin/skills/sglang/references/radix-attention.md +413 -0
- package/bin/skills/sglang/references/structured-generation.md +541 -0
- package/bin/skills/simpo/SKILL.md +219 -0
- package/bin/skills/simpo/references/datasets.md +478 -0
- package/bin/skills/simpo/references/hyperparameters.md +452 -0
- package/bin/skills/simpo/references/loss-functions.md +350 -0
- package/bin/skills/skypilot/SKILL.md +509 -0
- package/bin/skills/skypilot/references/advanced-usage.md +491 -0
- package/bin/skills/skypilot/references/troubleshooting.md +570 -0
- package/bin/skills/slime/SKILL.md +464 -0
- package/bin/skills/slime/references/api-reference.md +392 -0
- package/bin/skills/slime/references/troubleshooting.md +386 -0
- package/bin/skills/speculative-decoding/SKILL.md +467 -0
- package/bin/skills/speculative-decoding/references/lookahead.md +309 -0
- package/bin/skills/speculative-decoding/references/medusa.md +350 -0
- package/bin/skills/stable-diffusion/SKILL.md +519 -0
- package/bin/skills/stable-diffusion/references/advanced-usage.md +716 -0
- package/bin/skills/stable-diffusion/references/troubleshooting.md +555 -0
- package/bin/skills/tensorboard/SKILL.md +629 -0
- package/bin/skills/tensorboard/references/integrations.md +638 -0
- package/bin/skills/tensorboard/references/profiling.md +545 -0
- package/bin/skills/tensorboard/references/visualization.md +620 -0
- package/bin/skills/tensorrt-llm/SKILL.md +187 -0
- package/bin/skills/tensorrt-llm/references/multi-gpu.md +298 -0
- package/bin/skills/tensorrt-llm/references/optimization.md +242 -0
- package/bin/skills/tensorrt-llm/references/serving.md +470 -0
- package/bin/skills/tinker/SKILL.md +362 -0
- package/bin/skills/tinker/references/api-reference.md +168 -0
- package/bin/skills/tinker/references/getting-started.md +157 -0
- package/bin/skills/tinker/references/loss-functions.md +163 -0
- package/bin/skills/tinker/references/models-and-lora.md +139 -0
- package/bin/skills/tinker/references/recipes.md +280 -0
- package/bin/skills/tinker/references/reinforcement-learning.md +212 -0
- package/bin/skills/tinker/references/rendering.md +243 -0
- package/bin/skills/tinker/references/supervised-learning.md +232 -0
- package/bin/skills/tinker-training-cost/SKILL.md +187 -0
- package/bin/skills/tinker-training-cost/scripts/calculate_cost.py +123 -0
- package/bin/skills/torchforge/SKILL.md +433 -0
- package/bin/skills/torchforge/references/api-reference.md +327 -0
- package/bin/skills/torchforge/references/troubleshooting.md +409 -0
- package/bin/skills/torchtitan/SKILL.md +358 -0
- package/bin/skills/torchtitan/references/checkpoint.md +181 -0
- package/bin/skills/torchtitan/references/custom-models.md +258 -0
- package/bin/skills/torchtitan/references/float8.md +133 -0
- package/bin/skills/torchtitan/references/fsdp.md +126 -0
- package/bin/skills/transformer-lens/SKILL.md +346 -0
- package/bin/skills/transformer-lens/references/README.md +54 -0
- package/bin/skills/transformer-lens/references/api.md +362 -0
- package/bin/skills/transformer-lens/references/tutorials.md +339 -0
- package/bin/skills/trl-fine-tuning/SKILL.md +455 -0
- package/bin/skills/trl-fine-tuning/references/dpo-variants.md +227 -0
- package/bin/skills/trl-fine-tuning/references/online-rl.md +82 -0
- package/bin/skills/trl-fine-tuning/references/reward-modeling.md +122 -0
- package/bin/skills/trl-fine-tuning/references/sft-training.md +168 -0
- package/bin/skills/unsloth/SKILL.md +80 -0
- package/bin/skills/unsloth/references/index.md +7 -0
- package/bin/skills/unsloth/references/llms-full.md +16799 -0
- package/bin/skills/unsloth/references/llms-txt.md +12044 -0
- package/bin/skills/unsloth/references/llms.md +82 -0
- package/bin/skills/verl/SKILL.md +391 -0
- package/bin/skills/verl/references/api-reference.md +301 -0
- package/bin/skills/verl/references/troubleshooting.md +391 -0
- package/bin/skills/vllm/SKILL.md +364 -0
- package/bin/skills/vllm/references/optimization.md +226 -0
- package/bin/skills/vllm/references/quantization.md +284 -0
- package/bin/skills/vllm/references/server-deployment.md +255 -0
- package/bin/skills/vllm/references/troubleshooting.md +447 -0
- package/bin/skills/weights-and-biases/SKILL.md +590 -0
- package/bin/skills/weights-and-biases/references/artifacts.md +584 -0
- package/bin/skills/weights-and-biases/references/integrations.md +700 -0
- package/bin/skills/weights-and-biases/references/sweeps.md +847 -0
- package/bin/skills/whisper/SKILL.md +317 -0
- package/bin/skills/whisper/references/languages.md +189 -0
- package/bin/synsc +0 -0
- package/package.json +10 -0
@@ -0,0 +1,4249 @@
+# Pytorch-Fsdp - Other
+
+**Pages:** 15
+
+---
+
+## Distributed Data Parallel
+
+**URL:** https://pytorch.org/docs/stable/notes/ddp.html
+
+**Contents:**
+- Distributed Data Parallel
+- Example
+- Internal Design
+- Implementation
+- ProcessGroup
+- DistributedDataParallel
+- TorchDynamo DDPOptimizer
+
+Created On: Jan 15, 2020 | Last Updated On: Jan 25, 2024
+
+The implementation of torch.nn.parallel.DistributedDataParallel evolves over time. This design note is written based on the state as of v1.4.
+
+torch.nn.parallel.DistributedDataParallel (DDP) transparently performs distributed data parallel training. This page describes how it works and reveals implementation details.
+
+Let us start with a simple torch.nn.parallel.DistributedDataParallel example. This example uses a torch.nn.Linear as the local model, wraps it with DDP, and then runs one forward pass, one backward pass, and an optimizer step on the DDP model. After that, parameters on the local model will be updated, and all models on different processes should be exactly the same.
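The example listing this paragraph refers to did not survive extraction. The following is a minimal sketch of the same steps, not the original code. It assumes the CPU-only gloo backend and a single-machine rendezvous on 127.0.0.1:29500 (our choices), and run_one_step is a hypothetical helper name:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def run_one_step(rank: int, world_size: int) -> torch.Tensor:
    """Hypothetical helper: one forward/backward/step on a DDP-wrapped nn.Linear."""
    # Assumed single-machine rendezvous; torchrun sets these, so only default them.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # gloo runs on CPU, so the sketch needs no GPU.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        torch.manual_seed(rank)  # replicas start out different on purpose
        model = nn.Linear(10, 10)  # the local model
        ddp_model = DDP(model)  # constructor broadcasts rank 0's state to all ranks

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        outputs = ddp_model(torch.randn(20, 10))  # forward pass
        labels = torch.randn(20, 10)
        loss_fn(outputs, labels).backward()  # backward pass; DDP hooks average grads
        optimizer.step()  # every rank applies the same averaged gradients

        # DDP wraps (does not copy) the local model, so its weight shows the update.
        return model.weight.detach().clone()
    finally:
        dist.destroy_process_group()
```

With world_size=1 the call runs in a single process, which makes a convenient smoke test. For two replicas, launch a script that calls run_one_step(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"])) with torchrun --nproc_per_node=2; the returned weights should then match across ranks.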
+
+DDP works with TorchDynamo. When used with TorchDynamo, apply the DDP model wrapper before compiling the model, so that TorchDynamo can apply DDPOptimizer (graph-break optimizations) based on DDP bucket sizes. (See TorchDynamo DDPOptimizer for more information.)
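A sketch of the ordering rule: wrap with DDP first, then compile the DDP wrapper. Assumptions: gloo on CPU, port 29501, and backend="eager", which we pick only so the sketch runs without a C++ or Triton toolchain (real jobs usually keep the default backend). compile_after_wrap is a hypothetical name:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def compile_after_wrap(rank: int, world_size: int) -> torch.Tensor:
    """Hypothetical helper: DDP-wrap, then torch.compile, then one step."""
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
        # Order matters: wrap first, so TorchDynamo sees the DDP module and can
        # place graph breaks at DDP bucket boundaries.
        ddp_model = DDP(model)
        # "eager" skips code generation; used here only for portability.
        compiled = torch.compile(ddp_model, backend="eager")
        out = compiled(torch.randn(8, 16))
        out.sum().backward()  # DDP's allreduce hooks fire during this backward
        return out.detach()
    finally:
        dist.destroy_process_group()
```

Compiling first and wrapping afterwards would hide DDP from TorchDynamo, which is the situation the note advises against.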
+
+This section reveals how torch.nn.parallel.DistributedDataParallel works under the hood by diving into the details of every step in one iteration.
+
+Prerequisite: DDP relies on c10d ProcessGroup for communications. Hence, applications must create ProcessGroup instances before constructing DDP.
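A sketch of the prerequisite: create the default process group, optionally a subgroup, and only then build DDP on top of it. The backend string passed to init_process_group selects the ProcessGroup implementation. gloo on CPU and port 29502 are assumptions, and build_ddp_on_group is a hypothetical helper:

```python
import os

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def build_ddp_on_group(rank: int, world_size: int) -> int:
    """Hypothetical helper: returns the size of the group the DDP model uses."""
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29502")
    # 1) The default process group must exist before any DDP instance is built.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        # 2) Optional subgroup; every rank must call new_group, even non-members.
        group = dist.new_group(ranks=list(range(world_size)))
        # 3) DDP communicates over process_group (the default group if omitted).
        ddp_model = DDP(nn.Linear(4, 4), process_group=group)
        return dist.get_world_size(group=ddp_model.process_group)
    finally:
        dist.destroy_process_group()  # tears down the subgroup as well
```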
+
+Construction: The DDP constructor takes a reference to the local module and broadcasts state_dict() from the process with rank 0 to all other processes in the group to make sure that all model replicas start from the exact same state. Then, each DDP process creates a local Reducer, which later takes care of gradient synchronization during the backward pass. To improve communication efficiency, the Reducer organizes parameter gradients into buckets and reduces one bucket at a time. The bucket size can be configured by setting the bucket_cap_mb argument in the DDP constructor. The mapping from parameter gradients to buckets is determined at construction time, based on the bucket size limit and parameter sizes. Model parameters are allocated into buckets in (roughly) the reverse order of model.parameters() from the given model. The reverse order is used because DDP expects gradients to become ready during the backward pass in approximately that order. The figure below shows an example. Note that grad0 and grad1 are in bucket1, and the other two gradients are in bucket0. Of course, this assumption might not always hold, and when it does not, DDP backward speed can suffer because the Reducer cannot kick off the communication at the earliest possible time. Besides bucketing, the Reducer also registers autograd hooks during construction, one hook per parameter. These hooks are triggered during the backward pass when the gradient becomes ready.
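To make the reverse-order bucketing concrete, here is a simplified greedy sketch. assign_buckets is hypothetical and is not DDP's internal routine, which applies extra rules (for example, keeping tensors of different dtypes or devices apart). In the real API the cap is set with DDP(model, bucket_cap_mb=...):

```python
from typing import List


def assign_buckets(param_bytes: List[int], bucket_cap_bytes: int) -> List[List[int]]:
    """Greedily pack parameter indices into buckets, walking them in reverse order.

    Simplified illustration only; the real Reducer's assignment differs in detail.
    """
    buckets: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for idx in reversed(range(len(param_bytes))):  # reverse of model.parameters()
        size = param_bytes[idx]
        # Close the current bucket if adding this gradient would exceed the cap.
        if current and current_bytes + size > bucket_cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets
```

With four equal-sized gradients and room for two per bucket, assign_buckets([4, 4, 4, 4], 8) returns [[3, 2], [1, 0]]: grad3 and grad2 land in bucket0, and grad1 and grad0 in bucket1, matching the example in the paragraph above.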
+
+Forward Pass: DDP takes the input and passes it to the local model, and then analyzes the output from the local model if find_unused_parameters is set to True. This mode allows running backward on a subgraph of the model: DDP finds out which parameters are involved in the backward pass by traversing the autograd graph from the model output, and marks all unused parameters as ready for reduction. During the backward pass, the Reducer only waits for unready parameters, but it still reduces all buckets. Marking a parameter gradient as ready does not currently help DDP skip buckets, but it prevents DDP from waiting forever for absent gradients during the backward pass. Note that traversing the autograd graph introduces extra overhead, so applications should set find_unused_parameters to True only when necessary.
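A sketch of when find_unused_parameters=True is needed: a hypothetical model (TwoBranch) whose unused branch never takes part in forward, so its gradients never become ready on their own. The note explains that the flag keeps DDP from waiting on such gradients. gloo, port 29503, and train_with_unused_branch are assumptions:

```python
import os
from typing import List

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class TwoBranch(nn.Module):
    """Hypothetical model: self.unused is registered but skipped in forward."""

    def __init__(self) -> None:
        super().__init__()
        self.used = nn.Linear(8, 8)
        self.unused = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.used(x)


def train_with_unused_branch(rank: int, world_size: int, steps: int = 2) -> List[float]:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29503")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        # The flag makes DDP traverse the autograd graph after each forward and
        # mark self.unused's parameters as ready for reduction.
        model = DDP(TwoBranch(), find_unused_parameters=True)
        opt = torch.optim.SGD(model.parameters(), lr=0.01)
        losses: List[float] = []
        for _ in range(steps):
            opt.zero_grad()
            loss = model(torch.randn(4, 8)).pow(2).mean()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        return losses
    finally:
        dist.destroy_process_group()
```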
+
+Backward Pass: The backward() function is invoked directly on the loss Tensor, which is outside DDP’s control, so DDP uses the autograd hooks registered at construction time to trigger gradient synchronization. When one gradient becomes ready, its corresponding DDP hook on that grad accumulator fires, and DDP marks that parameter gradient as ready for reduction. When all gradients in one bucket are ready, the Reducer kicks off an asynchronous allreduce on that bucket to calculate the mean of gradients across all processes. When all buckets are ready, the Reducer blocks waiting for all allreduce operations to finish. Once this is done, the averaged gradients are written to the param.grad field of all parameters. So after the backward pass, the grad field of the same parameter should be identical across all DDP processes.
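For intuition, the following sketch computes the same result as the Reducer's bucketed allreduce (a mean over processes), but one parameter at a time and synchronously, so it gives up DDP's overlap with backward computation. average_gradients and demo_manual_allreduce are hypothetical helpers; gloo and port 29504 are assumptions:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn


def average_gradients(model: nn.Module) -> None:
    """Unbucketed, blocking stand-in for the averaging DDP's Reducer performs."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is None:
            continue
        # allreduce sums across processes; dividing by world size gives the mean.
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad /= world_size


def demo_manual_allreduce(rank: int, world_size: int) -> torch.Tensor:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29504")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        torch.manual_seed(0)  # identical replicas, standing in for DDP's broadcast
        model = nn.Linear(3, 1)
        torch.manual_seed(1000 + rank)  # each rank sees different data
        model(torch.randn(5, 3)).sum().backward()
        average_gradients(model)
        return model.weight.grad.clone()
    finally:
        dist.destroy_process_group()
```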
+
+Optimizer Step: From the optimizer’s perspective, it is optimizing a local model. Model replicas on all DDP processes stay in sync because they all start from the same state and have the same averaged gradients in every iteration.
+
+DDP requires Reducer instances on all processes to invoke allreduce in exactly the same order, which is achieved by always running allreduce in bucket index order rather than in the order in which buckets actually become ready. Mismatched allreduce order across processes can lead to wrong results or a hang in DDP’s backward pass.
+
+Below are pointers to the DDP implementation components. The stacked graph shows the structure of the code.
+
+ProcessGroup.hpp: contains the abstract API of all process group implementations. The c10d library provides three implementations out of the box, namely ProcessGroupGloo, ProcessGroupNCCL, and ProcessGroupMPI. DistributedDataParallel uses ProcessGroup::broadcast() to send model states from the process with rank 0 to others during initialization and ProcessGroup::allreduce() to sum gradients.
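At the Python level, the broadcast step described above can be sketched with torch.distributed.broadcast. sync_from_rank0 is a hypothetical, uncoalesced simplification of what DDP does at construction; gloo (one of the three backends named above, selected by the "gloo" string) and port 29505 are assumptions:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn


def sync_from_rank0(module: nn.Module) -> None:
    """Overwrite every rank's parameters and buffers with rank 0's values."""
    with torch.no_grad():
        # state_dict() tensors share storage with the module, so the in-place
        # broadcast updates the module itself.
        for tensor in module.state_dict().values():
            dist.broadcast(tensor, src=0)


def demo_broadcast(rank: int, world_size: int) -> torch.Tensor:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29505")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        torch.manual_seed(rank)  # deliberately different initialization per rank
        model = nn.Linear(4, 2)
        sync_from_rank0(model)  # afterwards every rank holds rank 0's weights
        return model.weight.detach().clone()
    finally:
        dist.destroy_process_group()
```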
+
+Store.hpp: assists the rendezvous service that lets process group instances find each other.
+
+distributed.py: serves as the Python entry point for DDP. It implements the initialization steps and the forward function for the nn.parallel.DistributedDataParallel module, which call into C++ libraries. Its _sync_param function performs intra-process parameter synchronization when one DDP process works on multiple devices, and it also broadcasts model buffers from the process with rank 0 to all other processes. The inter-process parameter synchronization happens in Reducer.cpp.
+
+comm.h: implements the coalesced broadcast helper function, which is invoked to broadcast model states during initialization and to synchronize model buffers before the forward pass.
+
+reducer.h: provides the core implementation for gradient synchronization in the backward pass. It has three entry point functions:
+
+Reducer: The constructor is called in distributed.py and registers Reducer::autograd_hook() with gradient accumulators.
+
+autograd_hook(): This function is invoked by the autograd engine when a gradient becomes ready.
+
+prepare_for_backward(): This function is called at the end of the DDP forward pass in distributed.py. It traverses the autograd graph to find unused parameters when find_unused_parameters is set to True in the DDP constructor.
|
|
61
|
+
|
|
62
|
+
DDP’s performance advantage comes from overlapping allreduce collectives with computations during backwards. AotAutograd prevents this overlap when used with TorchDynamo for compiling a whole forward and whole backward graph, because allreduce ops are launched by autograd hooks _after_ the whole optimized backwards computation finishes.
|
|
63
|
+
|
|
64
|
+
TorchDynamo’s DDPOptimizer helps by breaking the forward graph at the logical boundaries of DDP’s allreduce buckets during backwards. Note: the goal is to break the graph during backwards, and the simplest implementation is to break the forward graphs and then call AotAutograd and compilation on each section. This allows DDP’s allreduce hooks to fire in-between sections of backwards, and schedule communications to overlap with compute.
|
|
65
|
+
|
|
66
|
+
See this blog post for a more in-depth explanation and experimental results, or read the docs and code at torch/_dynamo/optimizations/distributed.py
|
|
67
|
+
|
|
68
|
+
To debug DDPOptimizer, set TORCH_LOGS='ddp_graphs' for full graph dumps. For logs without graphs, add any of 'dynamo', 'distributed', or 'dist_ddp' to TORCH_LOGS (for basic information about bucket boundaries). To disable DDPOptimizer, set torch._dynamo.config.optimize_ddp=False. DDP and TorchDynamo should still work correctly without DDPOptimizer, but with degraded performance.
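The sketch below gives a rough idea of where DDPOptimizer might place graph breaks. It rests on these assumptions:

* Each layer's parameter size in bytes is known.
* Buckets fill in reverse (backward) order until a byte cap is reached, in the spirit of DDP's bucket_cap_mb.

The greedy rule and the name `partition_by_buckets` are illustrative assumptions, not DDPOptimizer's exact algorithm.

```python
from typing import List, Sequence


def partition_by_buckets(layer_sizes: Sequence[int], bucket_cap: int) -> List[List[int]]:
    """Group layer indices into sections that each roughly fill one bucket.

    Layers are visited last-to-first, because gradients become ready in
    reverse order during backward. A bucket closes once it reaches bucket_cap.
    The sections are returned in forward order.
    """
    buckets: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for idx in reversed(range(len(layer_sizes))):
        current.append(idx)
        current_bytes += layer_sizes[idx]
        if current_bytes >= bucket_cap:
            buckets.append(sorted(current))
            current, current_bytes = [], 0
    if current:  # leftover layers form the final (earliest) bucket
        buckets.append(sorted(current))
    return list(reversed(buckets))


# Four equal layers with room for two per bucket give two compiled sections.
print(partition_by_buckets([4, 4, 4, 4], bucket_cap=8))  # [[0, 1], [2, 3]]
```

Each returned section would be compiled separately, so the allreduce for one section's bucket can overlap with the backward computation of the earlier sections.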
|
|
69
|
+
|
|
70
|
+
---
|
|
71
|
+
|
|
72
|
+
## PyTorch documentation
|
|
73
|
+
|
|
74
|
+
**URL:** https://pytorch.org/docs/stable/
|
|
75
|
+
|
|
76
|
+
**Contents:**
|
|
77
|
+
- PyTorch documentation
|
|
78
|
+
- Indices and tables
|
|
79
|
+
|
|
80
|
+
PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
|
|
81
|
+
|
|
82
|
+
Features described in this documentation are classified by release status:
|
|
83
|
+
|
|
84
|
+
Stable (API-Stable): These features will be maintained long-term and there should generally be no major performance limitations or gaps in documentation. We also expect to maintain backwards compatibility (although breaking changes can happen and notice will be given one release ahead of time).
|
|
85
|
+
|
|
86
|
+
Unstable (API-Unstable): Encompasses all features that are under active development, where APIs may change based on user feedback, requisite performance improvements, or because coverage across operators is not yet complete. The APIs and performance characteristics of these features may change.
|
|
87
|
+
|
|
88
|
+
---
|
|
89
|
+
|
|
90
|
+
## Generic Join Context Manager
|
|
91
|
+
|
|
92
|
+
**URL:** https://pytorch.org/docs/stable/distributed.algorithms.join.html
|
|
93
|
+
|
|
94
|
+
**Contents:**
|
|
95
|
+
- Generic Join Context Manager
|
|
96
|
+
|
|
97
|
+
Created On: Jun 06, 2025 | Last Updated On: Jun 06, 2025
|
|
98
|
+
|
|
99
|
+
The generic join context manager facilitates distributed training on uneven inputs. This page outlines the API of the relevant classes: Join, Joinable, and JoinHook. For a tutorial, see Distributed Training with Uneven Inputs Using the Join Context Manager.
|
|
100
|
+
|
|
101
|
+
This class defines the generic join context manager, which allows custom hooks to be called after a process joins.
|
|
102
|
+
|
|
103
|
+
These hooks should shadow the collective communications of non-joined processes to prevent hanging and erroring and to ensure algorithmic correctness. Refer to JoinHook for details about the hook definition.
|
|
104
|
+
|
|
105
|
+
The context manager requires each participating Joinable to call the method notify_join_context() before its own per-iteration collective communications to ensure correctness.
|
|
106
|
+
|
|
107
|
+
The context manager requires that all process_group attributes in the JoinHook objects are the same. If there are multiple JoinHook objects, then the device of the first is used. The process group and device information is used for checking for non-joined processes and for notifying processes to throw an exception if throw_on_early_termination is enabled, both of which use an all-reduce.
|
|
108
|
+
|
|
109
|
+
joinables (List[Joinable]) – a list of the participating Joinables; their hooks are iterated over in the given order.
|
|
110
|
+
|
|
111
|
+
enable (bool) – a flag enabling uneven input detection; setting to False disables the context manager’s functionality and should only be set when the user knows the inputs will not be uneven (default: True).
|
|
112
|
+
|
|
113
|
+
throw_on_early_termination (bool) – a flag controlling whether to throw an exception upon detecting uneven inputs (default: False).
|
|
114
|
+
|
|
115
|
+
Notifies the join context manager that the calling process has not yet joined.
|
|
116
|
+
|
|
117
|
+
Then, if throw_on_early_termination=True, checks if uneven inputs have been detected (i.e. if one process has already joined) and throws an exception if so.
|
|
118
|
+
|
|
119
|
+
This method should be called from a Joinable object before its per-iteration collective communications. For example, this should be called at the beginning of the forward pass in DistributedDataParallel.
|
|
120
|
+
|
|
121
|
+
Only the first Joinable object passed into the context manager performs the collective communications in this method, and for the others, this method is vacuous.
|
|
122
|
+
|
|
123
|
+
joinable (Joinable) – the Joinable object calling this method.
|
|
124
|
+
|
|
125
|
+
An async work handle for the all-reduce meant to notify the context manager that the process has not yet joined if joinable is the first one passed into the context manager; None otherwise.
|
|
126
|
+
|
|
127
|
+
This defines an abstract base class for joinable classes.
|
|
128
|
+
|
|
129
|
+
A joinable class (inheriting from Joinable) should implement join_hook(), which returns a JoinHook instance, in addition to join_device() and join_process_group() that return device and process group information, respectively.
|
|
130
|
+
|
|
131
|
+
Return the device from which to perform collective communications needed by the join context manager.
|
|
132
|
+
|
|
133
|
+
Return a JoinHook instance for the given Joinable.
|
|
134
|
+
|
|
135
|
+
kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs.
|
|
136
|
+
|
|
137
|
+
Returns the process group for the collective communications needed by the join context manager itself.
|
|
138
|
+
|
|
139
|
+
This defines a join hook, which provides two entry points in the join context manager.
|
|
140
|
+
|
|
141
|
+
Entry points: a main hook, which is called repeatedly while there exists a non-joined process, and a post-hook, which is called once all processes have joined.
|
|
142
|
+
|
|
143
|
+
To implement a join hook for the generic join context manager, define a class that inherits from JoinHook and override main_hook() and post_hook() as appropriate.
|
|
144
|
+
|
|
145
|
+
Call this hook while there exists a non-joined process to shadow collective communications in a training iteration.
|
|
146
|
+
|
|
147
|
+
A training iteration means one forward pass, backward pass, and optimizer step.
|
|
148
|
+
|
|
149
|
+
Call this hook after all processes have joined.
|
|
150
|
+
|
|
151
|
+
It is passed an additional bool argument is_last_joiner, which indicates if the rank is one of the last to join.
|
|
152
|
+
|
|
153
|
+
is_last_joiner (bool) – True if the rank is one of the last to join; False otherwise.
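To make the main-hook/post-hook contract concrete, here is a single-process simulation under simplifying assumptions:

* Each "rank" is just an entry in a list.
* The shadowed collective is only counted, not executed.
* `CountingHook` and `run_with_join` are hypothetical names.

In real code you would subclass torch.distributed.algorithms.join.JoinHook and let Join drive it.

```python
from typing import List, Optional


class CountingHook:
    """Hypothetical stand-in for a JoinHook subclass (not torch's class)."""

    def __init__(self) -> None:
        self.shadowed = 0                           # iterations shadowed by this rank
        self.is_last_joiner: Optional[bool] = None  # set by post_hook

    def main_hook(self) -> None:
        # A real hook would issue the same collectives as non-joined ranks
        # (e.g. an all-reduce of zeros) so that they do not hang.
        self.shadowed += 1

    def post_hook(self, is_last_joiner: bool) -> None:
        self.is_last_joiner = is_last_joiner


def run_with_join(batches_per_rank: List[int]) -> List[CountingHook]:
    """Simulate ranks with uneven inputs stepping in lockstep."""
    hooks = [CountingHook() for _ in batches_per_rank]
    step = 0
    # Keep iterating while at least one rank has not joined (still has data).
    while any(step < n for n in batches_per_rank):
        for rank, n in enumerate(batches_per_rank):
            if step >= n:  # this rank has joined, so it shadows the iteration
                hooks[rank].main_hook()
        step += 1
    # Every rank has joined: each post-hook runs exactly once.
    longest = max(batches_per_rank)
    for rank, n in enumerate(batches_per_rank):
        hooks[rank].post_hook(is_last_joiner=(n == longest))
    return hooks


hooks = run_with_join([3, 1, 2])
print([h.shadowed for h in hooks])        # [0, 2, 1]
print([h.is_last_joiner for h in hooks])  # [True, False, False]
```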
|
|
154
|
+
|
|
155
|
+
---
|
|
156
|
+
|
|
157
|
+
## Experimental Object Oriented Distributed API
|
|
158
|
+
|
|
159
|
+
**URL:** https://pytorch.org/docs/stable/distributed._dist2.html
|
|
160
|
+
|
|
161
|
+
**Contents:**
|
|
162
|
+
- Experimental Object Oriented Distributed API
|
|
163
|
+
|
|
164
|
+
Created On: Jul 09, 2025 | Last Updated On: Jul 30, 2025
|
|
165
|
+
|
|
166
|
+
This is an experimental new API for PyTorch Distributed. It is under active development and may change or be removed entirely.
|
|
167
|
+
|
|
168
|
+
This is intended as a proving ground for more flexible and object oriented distributed APIs.
|
|
169
|
+
|
|
170
|
+
Bases: pybind11_object
|
|
171
|
+
|
|
172
|
+
A ProcessGroup is a communication primitive that allows for collective operations across a group of processes.
|
|
173
|
+
|
|
174
|
+
This is a base class that provides the interface for all ProcessGroups. It is not meant to be used directly, but rather extended by subclasses.
|
|
175
|
+
|
|
176
|
+
Bases: pybind11_object
|
|
177
|
+
|
|
178
|
+
The type of the backend used for the process group.
|
|
179
|
+
|
|
180
|
+
Abort all operations and connections if supported by the backend.
|
|
181
|
+
|
|
182
|
+
allgather(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[collections.abc.Sequence[torch.Tensor]], input_tensors: collections.abc.Sequence[torch.Tensor], opts: torch._C._distributed_c10d.AllgatherOptions = <torch._C._distributed_c10d.AllgatherOptions object at 0x7f0162b6b9b0>) -> c10d::Work
|
|
183
|
+
|
|
184
|
+
Allgathers the input tensors from all processes across the process group.
|
|
185
|
+
|
|
186
|
+
See torch.distributed.all_gather() for more details.
|
|
187
|
+
|
|
188
|
+
allgather(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[torch.Tensor], input_tensor: torch.Tensor, timeout: datetime.timedelta | None = None) -> c10d::Work
|
|
189
|
+
|
|
190
|
+
Allgathers the input tensors from all processes across the process group.
|
|
191
|
+
|
|
192
|
+
See torch.distributed.all_gather() for more details.
|
|
193
|
+
|
|
194
|
+
Allgathers the input tensors from all processes across the process group.
|
|
195
|
+
|
|
196
|
+
See torch.distributed.all_gather() for more details.
|
|
197
|
+
|
|
198
|
+
Allgathers the input tensors from all processes across the process group.
|
|
199
|
+
|
|
200
|
+
See torch.distributed.all_gather() for more details.
|
|
201
|
+
|
|
202
|
+
allreduce(self: torch._C._distributed_c10d.ProcessGroup, tensors: collections.abc.Sequence[torch.Tensor], opts: torch._C._distributed_c10d.AllreduceOptions = <torch._C._distributed_c10d.AllreduceOptions object at 0x7f0162745db0>) -> c10d::Work
|
|
203
|
+
|
|
204
|
+
Allreduces the provided tensors across all processes in the process group.
|
|
205
|
+
|
|
206
|
+
See torch.distributed.all_reduce() for more details.
|
|
207
|
+
|
|
208
|
+
allreduce(self: torch._C._distributed_c10d.ProcessGroup, tensors: collections.abc.Sequence[torch.Tensor], op: torch._C._distributed_c10d.ReduceOp = <RedOpType.SUM: 0>, timeout: datetime.timedelta | None = None) -> c10d::Work
|
|
209
|
+
|
|
210
|
+
Allreduces the provided tensors across all processes in the process group.
|
|
211
|
+
|
|
212
|
+
See torch.distributed.all_reduce() for more details.
|
|
213
|
+
|
|
214
|
+
allreduce(self: torch._C._distributed_c10d.ProcessGroup, tensor: torch.Tensor, op: torch._C._distributed_c10d.ReduceOp = <RedOpType.SUM: 0>, timeout: datetime.timedelta | None = None) -> c10d::Work
|
|
215
|
+
|
|
216
|
+
Allreduces the provided tensors across all processes in the process group.
|
|
217
|
+
|
|
218
|
+
See torch.distributed.all_reduce() for more details.
|
|
219
|
+
|
|
220
|
+
Allreduces the provided tensors across all processes in the process group.
|
|
221
|
+
|
|
222
|
+
See torch.distributed.all_reduce() for more details.
|
|
223
|
+
|
|
224
|
+
Alltoalls the input tensors from all processes across the process group.
|
|
225
|
+
|
|
226
|
+
See torch.distributed.all_to_all() for more details.
|
|
227
|
+
|
|
228
|
+
alltoall_base(self: torch._C._distributed_c10d.ProcessGroup, output: torch.Tensor, input: torch.Tensor, output_split_sizes: collections.abc.Sequence[typing.SupportsInt], input_split_sizes: collections.abc.Sequence[typing.SupportsInt], opts: torch._C._distributed_c10d.AllToAllOptions = <torch._C._distributed_c10d.AllToAllOptions object at 0x7f0162b79d30>) -> c10d::Work
|
|
229
|
+
|
|
230
|
+
Alltoalls the input tensors from all processes across the process group.
|
|
231
|
+
|
|
232
|
+
See torch.distributed.all_to_all() for more details.
|
|
233
|
+
|
|
234
|
+
alltoall_base(self: torch._C._distributed_c10d.ProcessGroup, output: torch.Tensor, input: torch.Tensor, output_split_sizes: collections.abc.Sequence[typing.SupportsInt], input_split_sizes: collections.abc.Sequence[typing.SupportsInt], timeout: datetime.timedelta | None = None) -> c10d::Work
|
|
235
|
+
|
|
236
|
+
Alltoalls the input tensors from all processes across the process group.
|
|
237
|
+
|
|
238
|
+
See torch.distributed.all_to_all() for more details.
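For intuition, the exchange pattern of all_to_all can be simulated with nested lists standing in for tensors. Uneven output_split_sizes/input_split_sizes are not modeled, and `simulate_all_to_all` is a hypothetical helper. The src-th rank's dst-th chunk ends up as the dst-th rank's src-th chunk, which amounts to a transpose.

```python
from typing import List, TypeVar

T = TypeVar("T")


def simulate_all_to_all(inputs_per_rank: List[List[T]]) -> List[List[T]]:
    """inputs_per_rank[src][dst] is the chunk that rank `src` sends to rank `dst`.

    After the exchange, rank `dst` holds [chunk from rank 0, chunk from rank 1, ...].
    """
    world_size = len(inputs_per_rank)
    return [
        [inputs_per_rank[src][dst] for src in range(world_size)]
        for dst in range(world_size)
    ]


print(simulate_all_to_all([["a0", "a1"], ["b0", "b1"]]))  # [['a0', 'b0'], ['a1', 'b1']]
```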
|
|
239
|
+
|
|
240
|
+
barrier(self: torch._C._distributed_c10d.ProcessGroup, opts: torch._C._distributed_c10d.BarrierOptions = <torch._C._distributed_c10d.BarrierOptions object at 0x7f0162745ab0>) -> c10d::Work
|
|
241
|
+
|
|
242
|
+
Blocks until all processes in the group enter the call, and then all leave the call together.
|
|
243
|
+
|
|
244
|
+
See torch.distributed.barrier() for more details.
|
|
245
|
+
|
|
246
|
+
barrier(self: torch._C._distributed_c10d.ProcessGroup, timeout: datetime.timedelta | None = None) -> c10d::Work
|
|
247
|
+
|
|
248
|
+
Blocks until all processes in the group enter the call, and then all leave the call together.
|
|
249
|
+
|
|
250
|
+
See torch.distributed.barrier() for more details.
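The barrier semantics match Python's threading.Barrier: every member blocks until all members arrive, and then all proceed. A thread-based simulation makes this observable. Threads stand in for processes here, this is not how c10d implements barrier, and `run_barrier_demo` is a hypothetical helper.

```python
import threading
from typing import List, Tuple


def run_barrier_demo(world_size: int) -> List[Tuple[str, int]]:
    """Record enter/leave events for threads that meet at a barrier."""
    barrier = threading.Barrier(world_size)
    events: List[Tuple[str, int]] = []
    lock = threading.Lock()

    def worker(rank: int) -> None:
        with lock:
            events.append(("enter", rank))
        barrier.wait()  # blocks until all `world_size` threads have arrived
        with lock:
            events.append(("leave", rank))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return events


events = run_barrier_demo(4)
kinds = [kind for kind, _ in events]
print(kinds.index("leave"))  # 4: every "enter" is recorded before any "leave"
```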
|
|
251
|
+
|
|
252
|
+
broadcast(self: torch._C._distributed_c10d.ProcessGroup, tensors: collections.abc.Sequence[torch.Tensor], opts: torch._C._distributed_c10d.BroadcastOptions = <torch._C._distributed_c10d.BroadcastOptions object at 0x7f0162b7afb0>) -> c10d::Work
|
|
253
|
+
|
|
254
|
+
Broadcasts the tensor to all processes in the process group.
|
|
255
|
+
|
|
256
|
+
See torch.distributed.broadcast() for more details.
|
|
257
|
+
|
|
258
|
+
broadcast(self: torch._C._distributed_c10d.ProcessGroup, tensor: torch.Tensor, root: typing.SupportsInt, timeout: datetime.timedelta | None = None) -> c10d::Work
|
|
259
|
+
|
|
260
|
+
Broadcasts the tensor to all processes in the process group.
|
|
261
|
+
|
|
262
|
+
See torch.distributed.broadcast() for more details.
|
|
263
|
+
|
|
264
|
+
gather(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[collections.abc.Sequence[torch.Tensor]], input_tensors: collections.abc.Sequence[torch.Tensor], opts: torch._C._distributed_c10d.GatherOptions = <torch._C._distributed_c10d.GatherOptions object at 0x7f0162c301f0>) -> c10d::Work
|
|
265
|
+
|
|
266
|
+
Gathers the input tensors from all processes across the process group.
|
|
267
|
+
|
|
268
|
+
See torch.distributed.gather() for more details.
|
|
269
|
+
|
|
270
|
+
gather(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[torch.Tensor], input_tensor: torch.Tensor, root: typing.SupportsInt, timeout: datetime.timedelta | None = None) -> c10d::Work
|
|
271
|
+
|
|
272
|
+
Gathers the input tensors from all processes across the process group.
|
|
273
|
+
|
|
274
|
+
See torch.distributed.gather() for more details.
|
|
275
|
+
|
|
276
|
+
Get the store of this process group.
|
|
277
|
+
|
|
278
|
+
Get this process group's description.
|
|
279
|
+
|
|
280
|
+
Get this process group's name, which is unique across the cluster.
|
|
281
|
+
|
|
282
|
+
Blocks until all processes in the group enter the call, and then all leave the call together.
|
|
283
|
+
|
|
284
|
+
See torch.distributed.monitored_barrier() for more details.
|
|
285
|
+
|
|
286
|
+
Get the name of this process group.
|
|
287
|
+
|
|
288
|
+
Get the rank of the calling process within this process group.
|
|
289
|
+
|
|
290
|
+
Receives the tensor from the specified rank.
|
|
291
|
+
|
|
292
|
+
See torch.distributed.recv() for more details.
|
|
293
|
+
|
|
294
|
+
Receives the tensor from any source.
|
|
295
|
+
|
|
296
|
+
See torch.distributed.recv() for more details.
|
|
297
|
+
|
|
298
|
+
reduce(self: torch._C._distributed_c10d.ProcessGroup, tensors: collections.abc.Sequence[torch.Tensor], opts: torch._C._distributed_c10d.ReduceOptions = <torch._C._distributed_c10d.ReduceOptions object at 0x7f0162bce3f0>) -> c10d::Work
|
|
299
|
+
|
|
300
|
+
Reduces the provided tensors across all processes in the process group.
|
|
301
|
+
|
|
302
|
+
See torch.distributed.reduce() for more details.
|
|
303
|
+
|
|
304
|
+
reduce(self: torch._C._distributed_c10d.ProcessGroup, tensor: torch.Tensor, root: typing.SupportsInt, op: torch._C._distributed_c10d.ReduceOp = <RedOpType.SUM: 0>, timeout: datetime.timedelta | None = None) -> c10d::Work
|
|
305
|
+
|
|
306
|
+
Reduces the provided tensors across all processes in the process group.
|
|
307
|
+
|
|
308
|
+
See torch.distributed.reduce() for more details.
|
|
309
|
+
|
|
310
|
+
reduce_scatter(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[torch.Tensor], input_tensors: collections.abc.Sequence[collections.abc.Sequence[torch.Tensor]], opts: torch._C._distributed_c10d.ReduceScatterOptions = <torch._C._distributed_c10d.ReduceScatterOptions object at 0x7f0162ee5cf0>) -> c10d::Work
|
|
311
|
+
|
|
312
|
+
Reduces and scatters the input tensors from all processes across the process group.
|
|
313
|
+
|
|
314
|
+
See torch.distributed.reduce_scatter() for more details.
|
|
315
|
+
|
|
316
|
+
reduce_scatter(self: torch._C._distributed_c10d.ProcessGroup, output: torch.Tensor, input: collections.abc.Sequence[torch.Tensor], op: torch._C._distributed_c10d.ReduceOp = <RedOpType.SUM: 0>, timeout: datetime.timedelta | None = None) -> c10d::Work
|
|
317
|
+
|
|
318
|
+
Reduces and scatters the input tensors from all processes across the process group.
|
|
319
|
+
|
|
320
|
+
See torch.distributed.reduce_scatter() for more details.
|
|
321
|
+
|
|
322
|
+
Reduces and scatters the input tensors from all processes across the process group.
|
|
323
|
+
|
|
324
|
+
See torch.distributed.reduce_scatter() for more details.
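The following stdlib-only sketch shows what reduce_scatter computes and why following it with an all-gather reproduces an all-reduce (sum). Lists stand in for tensors, and the function names are hypothetical. The reduce-scatter half is the collective FSDP uses for gradients.

```python
from typing import List


def reduce_scatter(inputs_per_rank: List[List[List[float]]]) -> List[List[float]]:
    """inputs_per_rank[src][dst] is the chunk that rank `src` contributes for rank `dst`.

    Rank `dst` receives the elementwise sum of every rank's chunk for `dst`.
    """
    world_size = len(inputs_per_rank)
    outputs = []
    for dst in range(world_size):
        chunks = [inputs_per_rank[src][dst] for src in range(world_size)]
        outputs.append([sum(vals) for vals in zip(*chunks)])
    return outputs


def all_gather(shards: List[List[float]]) -> List[List[float]]:
    """Every rank receives the concatenation of all shards, in rank order."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


# Two ranks, each holding a 4-element gradient split into 2 chunks.
grads = [[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]]
shards = reduce_scatter(grads)
print(shards)              # [[11.0, 22.0], [33.0, 44.0]]
print(all_gather(shards))  # [[11.0, 22.0, 33.0, 44.0], [11.0, 22.0, 33.0, 44.0]]
```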
|
|
325
|
+
|
|
326
|
+
scatter(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[torch.Tensor], input_tensors: collections.abc.Sequence[collections.abc.Sequence[torch.Tensor]], opts: torch._C._distributed_c10d.ScatterOptions = <torch._C._distributed_c10d.ScatterOptions object at 0x7f0162b879f0>) -> c10d::Work
|
|
327
|
+
|
|
328
|
+
Scatters the input tensors from all processes across the process group.
|
|
329
|
+
|
|
330
|
+
See torch.distributed.scatter() for more details.
|
|
331
|
+
|
|
332
|
+
scatter(self: torch._C._distributed_c10d.ProcessGroup, output_tensor: torch.Tensor, input_tensors: collections.abc.Sequence[torch.Tensor], root: typing.SupportsInt, timeout: datetime.timedelta | None = None) -> c10d::Work
|
|
333
|
+
|
|
334
|
+
Scatters the input tensors from all processes across the process group.
|
|
335
|
+
|
|
336
|
+
See torch.distributed.scatter() for more details.
|
|
337
|
+
|
|
338
|
+
Sends the tensor to the specified rank.
|
|
339
|
+
|
|
340
|
+
See torch.distributed.send() for more details.
|
|
341
|
+
|
|
342
|
+
Sets the default timeout for all future operations.
|
|
343
|
+
|
|
344
|
+
Shut down the process group.
|
|
345
|
+
|
|
346
|
+
Get the size of this process group.
|
|
347
|
+
|
|
348
|
+
Protocol for process group factories.
|
|
349
|
+
|
|
350
|
+
Get the current process group. This is a thread-local method.
|
|
351
|
+
|
|
352
|
+
The current process group.
|
|
353
|
+
|
|
354
|
+
Create a new process group with the given backend and options. This group is independent and will not be globally registered and thus not usable via the standard torch.distributed.* APIs.
|
|
355
|
+
|
|
356
|
+
backend (str) – The backend to use for the process group.
|
|
357
|
+
|
|
358
|
+
timeout (timedelta) – The timeout for collective operations.
|
|
359
|
+
|
|
360
|
+
device (Union[str, device]) – The device to use for the process group.
|
|
361
|
+
|
|
362
|
+
**kwargs (object) – All remaining arguments are passed to the backend constructor. See the backend specific documentation for details.
|
|
363
|
+
|
|
364
|
+
Context manager for process groups. This is a thread-local method.
|
|
365
|
+
|
|
366
|
+
pg (ProcessGroup) – The process group to use.
|
|
367
|
+
|
|
368
|
+
Generator[None, None, None]
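The thread-local current group and the context manager described above follow a common Python pattern. The sketch below reproduces that pattern with threading.local and contextlib.contextmanager. It does not call the experimental API: `current_group` and `use_group` are hypothetical names, and any object can play the role of a group.

```python
import contextlib
import threading
from typing import Generator, Optional

_state = threading.local()  # hypothetical per-thread storage for the active group


def current_group() -> Optional[object]:
    """Return this thread's active group, or None if none is set."""
    return getattr(_state, "group", None)


@contextlib.contextmanager
def use_group(pg: object) -> Generator[None, None, None]:
    """Make `pg` the current group for this thread and restore the previous one on exit."""
    previous = current_group()
    _state.group = pg
    try:
        yield
    finally:
        _state.group = previous


with use_group("outer"):
    with use_group("inner"):
        print(current_group())  # inner
    print(current_group())      # outer
print(current_group())          # None
```

Restoring the previous value in `finally` lets contexts nest safely. Storing it in `threading.local` means another thread never sees this thread's group.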
|
|
369
|
+
|
|
370
|
+
Register a new process group backend.
|
|
371
|
+
|
|
372
|
+
name (str) – The name of the backend.
|
|
373
|
+
|
|
374
|
+
func (ProcessGroupFactory) – The function to create the process group.
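Here is a toy version of a backend registry. It assumes that a factory receives the timeout, the device, and any extra keyword arguments, mirroring the new-group parameters listed above. The names `register_toy_backend`, `new_toy_group`, and `dict_backend` are hypothetical, not torch APIs. They only illustrate the lookup-and-forward idea.

```python
from datetime import timedelta
from typing import Any, Callable, Dict

# Hypothetical toy registry mapping backend names to group factories.
ToyFactory = Callable[..., Any]
_TOY_BACKENDS: Dict[str, ToyFactory] = {}


def register_toy_backend(name: str, func: ToyFactory) -> None:
    """Associate a backend name with the factory that builds its groups."""
    _TOY_BACKENDS[name] = func


def new_toy_group(backend: str, timeout: timedelta, device: str, **kwargs: Any) -> Any:
    """Build a group via the registered factory, forwarding extra kwargs to it."""
    if backend not in _TOY_BACKENDS:
        raise ValueError(f"unknown backend: {backend!r}")
    return _TOY_BACKENDS[backend](timeout=timeout, device=device, **kwargs)


def dict_backend(timeout: timedelta, device: str, **kwargs: Any) -> Dict[str, Any]:
    # A stand-in "process group" that only records how it was constructed.
    return {"timeout": timeout, "device": device, **kwargs}


register_toy_backend("dict", dict_backend)
pg = new_toy_group("dict", timedelta(seconds=30), "cpu", label="demo")
print(pg["device"], pg["label"])  # cpu demo
```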
|
|
375
|
+
|
|
376
|
+
---
|
|
377
|
+
|
|
378
|
+
## torch.distributed.fsdp.fully_shard
|
|
379
|
+
|
|
380
|
+
**URL:** https://pytorch.org/docs/stable/distributed.fsdp.fully_shard.html
|
|
381
|
+
|
|
382
|
+
**Contents:**
|
|
383
|
+
- torch.distributed.fsdp.fully_shard
|
|
384
|
+
- PyTorch FSDP2 (fully_shard)
|
|
385
|
+
|
|
386
|
+
Created On: Dec 04, 2024 | Last Updated On: Jun 16, 2025
|
|
387
|
+
|
|
388
|
+
PyTorch FSDP2 (RFC) provides a fully sharded data parallelism (FSDP) implementation that targets performant eager mode and uses per-parameter sharding for improved usability.
|
|
389
|
+
|
|
390
|
+
See the Getting Started with FSDP2 tutorial for more information.
|
|
391
|
+
|
|
392
|
+
If you are currently using FSDP1, consider migrating to FSDP2 using our migration guide.
|
|
393
|
+
|
|
394
|
+
The user contract for fully_shard(model) is as follows:
|
|
395
|
+
|
|
396
|
+
For model initialization, fully_shard converts model.parameters() from plain torch.Tensor to DTensor in-place. The parameters are moved to the appropriate device according to the device mesh.
|
|
397
|
+
|
|
398
|
+
Before forward and backward passes, pre-forward/backward hooks are responsible for all-gathering the parameters and converting model.parameters() from DTensor to plain torch.Tensor.
|
|
399
|
+
|
|
400
|
+
After forward and backward passes, post-forward/backward hooks free the unsharded parameters (no communication needed) and convert model.parameters() from plain torch.Tensor back to DTensor.
|
|
401
|
+
|
|
402
|
+
The optimizer must be initialized with the DTensor model.parameters(), and the optimizer step should be performed on DTensor parameters.
|
|
403
|
+
|
|
404
|
+
Call model(input) instead of model.forward(input) to trigger pre-forward hooks to all-gather parameters. To make model.forward(input) work, users must either call model.unshard() explicitly or use register_fsdp_forward_method(model, "forward") to register the forward method for hooking.
|
|
405
|
+
|
|
406
|
+
fully_shard groups parameters together for a single all-gather. Users should apply fully_shard in a bottom-up manner. For example, in a Transformer model, fully_shard should be applied to each layer before applying it to the root model. When applied to the root model, fully_shard excludes model.parameters() from each layer and groups the remaining parameters (e.g., embeddings, output projection) into a single all-gather group.
|
|
407
|
+
|
|
408
|
+
type(model) is “unioned” with FSDPModule in-place. For example, if model is originally of type nn.Linear, then fully_shard changes type(model) from nn.Linear to FSDPLinear in-place. FSDPLinear is an instance of both nn.Linear and FSDPModule. It retains all methods of nn.Linear while also exposing FSDP2-specific APIs under FSDPModule, such as reshard() and unshard().
|
|
409
|
+
|
|
410
|
+
Fully Qualified Names (FQNs) for parameters remain unchanged. If we call model.state_dict(), the FQNs are the same before and after applying fully_shard. This is because fully_shard does not wrap the module but only registers hooks to the original module.
|
|
411
|
+
|
|
412
|
+
Compared to PyTorch FSDP1 (FullyShardedDataParallel):
|
|
413
|
+
|
|
414
|
+
FSDP2 uses DTensor-based dim-0 per-parameter sharding for a simpler sharding representation compared to FSDP1’s flat-parameter sharding, while preserving similar throughput performance. More specifically, FSDP2 chunks each parameter on dim-0 across the data parallel workers (using torch.chunk(dim=0)), whereas FSDP1 flattens, concatenates, and chunks a group of tensors together, making reasoning about what data is present on each worker and resharding to different parallelisms complex. Per-parameter sharding provides a more intuitive user experience, relaxes constraints around frozen parameters, and allows for communication-free (sharded) state dicts, which otherwise require all-gathers in FSDP1.
|
|
415
|
+
|
|
416
|
+
FSDP2 handles multi-stream usage with a different memory management approach that avoids torch.Tensor.record_stream. This ensures deterministic and expected memory usage and does not require blocking the CPU, as FSDP1’s limit_all_gathers=True does.
|
|
417
|
+
|
|
418
|
+
FSDP2 exposes APIs for manual control over prefetching and collective scheduling, allowing power users more customization. See the methods on FSDPModule below for details.
|
|
419
|
+
|
|
420
|
+
FSDP2 simplifies some of the API surface. For example, FSDP2 does not directly support full state dicts. Instead, users can reshard the sharded state dicts containing DTensors to full state dicts themselves, either with DTensor APIs like DTensor.full_tensor() or with higher-level APIs like PyTorch Distributed Checkpoint’s distributed state dict APIs. Also, some other args have been removed; see here for details.
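The following sketch shows what per-parameter dim-0 sharding means for shard sizes. It computes how many rows each rank owns under torch.chunk(dim=0) semantics, where each chunk has ceil(n / world_size) rows. As a simplification, it assumes that ranks left without a chunk hold an empty shard. `dim0_shard_sizes` is a hypothetical helper, and the padding FSDP2 uses internally is not modeled.

```python
import math
from typing import List


def dim0_shard_sizes(dim0_size: int, world_size: int) -> List[int]:
    """Rows of one parameter owned by each rank under torch.chunk(dim=0) semantics.

    torch.chunk uses chunks of ceil(dim0_size / world_size) rows and may return
    fewer than world_size chunks. Trailing ranks then get 0 rows in this sketch.
    """
    chunk = math.ceil(dim0_size / world_size)
    sizes = []
    for rank in range(world_size):
        start = min(rank * chunk, dim0_size)
        end = min(start + chunk, dim0_size)
        sizes.append(end - start)
    return sizes


print(dim0_shard_sizes(10, 4))  # [3, 3, 3, 1]
print(dim0_shard_sizes(5, 4))   # [2, 2, 1, 0]
```

Because every parameter is chunked independently, each rank can tell which rows of which parameter it holds. That is much harder under FSDP1's flatten-and-concatenate scheme.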
|
|
421
|
+
|
|
422
|
+
The frontend API is fully_shard, which can be called on a module:
|
|
423
|
+
|
|
424
|
+
Apply fully sharded data parallelism (FSDP) to module, where FSDP shards module parameters, gradients, and optimizer states across data parallel workers to save memory at the cost of communication.
|
|
425
|
+
|
|
426
|
+
At initialization, FSDP shards the module’s parameters across the data parallel workers given by mesh. Before forward, FSDP all-gathers the sharded parameters across the data-parallel workers to get the unsharded parameters for forward computation. If reshard_after_forward is True, then FSDP frees the unsharded parameters after forward and re-all-gathers them in backward before gradient computation. After gradient computation, FSDP frees the unsharded parameters and reduce-scatters the unsharded gradients across data-parallel workers.
|
|
427
|
+
|
|
428
|
+
This implementation represents the sharded parameters as DTensors sharded on dim-0, while the unsharded parameters will be like the original parameters on module (e.g. torch.Tensor if originally torch.Tensor). A module forward pre-hook on module all-gathers the parameters, and a module forward hook on module frees them (if needed). Similar backward hooks all-gather parameters and later free parameters and reduce-scatter gradients.
|
|
429
|
+
|
|
430
|
+
Since grouping multiple tensors together for one collective is critical for communication efficiency, this implementation makes this grouping first class. Calling fully_shard() on module constructs one group that includes the parameters in module.parameters() except those already assigned to a group from an earlier call on a submodule. This means that fully_shard() should be called bottom-up on your model. Each group’s parameters are all-gathered in one collective, and its gradients are reduce-scattered in one collective. Partitioning the model into multiple groups (“layer by layer”) allows for peak memory savings and communication/computation overlap. Users generally should not call fully_shard() only on the topmost root module.
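The grouping rule can be simulated without torch: each call claims the parameters in module.parameters() that no earlier call has already assigned. `ToyModule` and `fully_shard_groups` are hypothetical stand-ins. Comparing the bottom-up call order with a root-only call shows why the root-only pattern produces a single large group.

```python
from typing import List, Optional, Set


class ToyModule:
    """Hypothetical stand-in for nn.Module: its own parameter names plus children."""

    def __init__(self, params: List[str], children: Optional[List["ToyModule"]] = None) -> None:
        self.own_params = list(params)
        self.children = list(children or [])

    def parameters(self) -> List[str]:
        # Recursive, like nn.Module.parameters(): own parameters first, then children.
        names = list(self.own_params)
        for child in self.children:
            names.extend(child.parameters())
        return names


def fully_shard_groups(call_order: List[ToyModule]) -> List[List[str]]:
    """One communication group per call, skipping parameters claimed by earlier calls."""
    assigned: Set[str] = set()
    groups: List[List[str]] = []
    for module in call_order:
        group = [p for p in module.parameters() if p not in assigned]
        assigned.update(group)
        groups.append(group)
    return groups


layers = [ToyModule([f"layers.{i}.weight"]) for i in range(2)]
root = ToyModule(["embed.weight", "output.weight"], layers)

print(fully_shard_groups(layers + [root]))
# [['layers.0.weight'], ['layers.1.weight'], ['embed.weight', 'output.weight']]
print(fully_shard_groups([root]))
# [['embed.weight', 'output.weight', 'layers.0.weight', 'layers.1.weight']]
```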
|
|
431
|
+
|
|
432
|
+
module (Union[nn.Module, List[nn.Module]]) – The module or modules to shard with FSDP and group together for communication.
|
|
433
|
+
|
|
434
|
+
mesh (Optional[DeviceMesh]) – This data parallel mesh defines the sharding and device. If 1D, then parameters are fully sharded across the 1D mesh (FSDP) with (Shard(0),) placement. If 2D, then parameters are sharded across the 1st dim and replicated across the 0th dim (HSDP) with (Replicate(), Shard(0)) placement. The mesh’s device type gives the device type used for communication; if a CUDA or CUDA-like device type, then we use the current device.
|
|
435
|
+
|
|
436
|
+
reshard_after_forward (Optional[Union[bool, int]]):
|
|
437
|
+
|
|
438
|
+
This controls the parameter behavior after forward and can trade off memory and communication:
|
|
439
|
+
|
|
440
|
+
If True, then this reshards parameters after forward and re-all-gathers in backward.
|
|
441
|
+
|
|
442
|
+
If False, then this keeps the unsharded parameters in memory after forward and avoids the all-gather in backward. For best performance, we usually set False for the root module, because the root module is typically required immediately when the backward pass begins.
|
|
443
|
+
|
|
444
|
+
If None, it is set to True for non-root modules and False for root modules.
|
|
445
|
+
|
|
446
|
+
If an int, then this represents the world size to reshard to after forward. It should be a non-trivial divisor of the mesh shard dim size (i.e. excluding 1 and the dim size itself). One possible choice is the intra-node size (e.g. torch.cuda.device_count()). This allows the all-gather in backward to run over a smaller world size, at the cost of higher memory usage than setting True.
|
|
447
|
+
|
|
448
|
+
After forward, the parameters registered to the module depend on this setting. The registered parameters are the sharded parameters if True, the unsharded parameters if False, and the parameters resharded to the smaller mesh otherwise. To modify the parameters between forward and backward, the registered parameters must be the sharded parameters. For False or an int, this can be done by manually resharding via reshard().
|
|
449
|
+
|
|
450
|
+
shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – This callable can be used to override the sharding placement for a parameter to shard a parameter on a dimension other than dim-0. If this callable returns a Shard placement (not None), then FSDP will shard according to that placement (e.g. Shard(1)). If sharding on a nonzero dim, we currently require even sharding, i.e. the tensor dim size on that dim must be divisible by the FSDP shard mesh size.
|
|
451
|
+
|
|
452
|
+
mp_policy (MixedPrecisionPolicy) – This controls the mixed precision policy, which offers parameter/reduction mixed precision for this module. See MixedPrecisionPolicy for details.
|
|
453
|
+
|
|
454
|
+
offload_policy (OffloadPolicy) – This controls the offloading policy, which offers parameter/gradient/optimizer state offloading. See OffloadPolicy and its subclasses for details.
|
|
455
|
+
|
|
456
|
+
ignored_params (Optional[set[nn.Parameter]]) – The set of parameters to be ignored by FSDP. They will not be sharded, nor moved to the device during init, nor have their gradients reduced in backward.
|
|
457
|
+
|
|
458
|
+
The module with FSDP applied (in-place).
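The mesh and reshard_after_forward rules above can be checked with plain Python. This is a hedged sketch, not FSDP code, and it makes these assumptions:

* The 2D mesh numbers ranks row-major, as torch.arange(world_size).view(replicate, shard) would.
* `hsdp_groups` and `resolve_reshard_after_forward` are hypothetical helpers.
* Raising ValueError for an invalid int is this sketch's choice rather than documented FSDP behavior.

```python
from typing import Dict, List, Optional, Union


def hsdp_groups(replicate_size: int, shard_size: int) -> Dict[str, List[List[int]]]:
    """Rank groups for a 2D (replicate, shard) mesh, assuming row-major rank layout."""
    mesh = [[r * shard_size + s for s in range(shard_size)] for r in range(replicate_size)]
    return {
        # Ranks in one row split each parameter among themselves: Shard(0) on mesh dim 1.
        "shard": [list(row) for row in mesh],
        # Ranks in one column hold identical shards: Replicate() on mesh dim 0.
        "replicate": [[mesh[r][s] for r in range(replicate_size)] for s in range(shard_size)],
    }


def resolve_reshard_after_forward(
    value: Optional[Union[bool, int]], is_root: bool, shard_size: int
) -> Union[bool, int]:
    """Apply the None / bool / int rules for reshard_after_forward."""
    if value is None:
        return not is_root  # True for non-root modules, False for the root
    if isinstance(value, bool):  # check bool first, since bool subclasses int
        return value
    # An int must be a non-trivial divisor of the shard mesh dim size.
    if not (1 < value < shard_size and shard_size % value == 0):
        raise ValueError(f"{value} is not a non-trivial divisor of {shard_size}")
    return value


groups = hsdp_groups(replicate_size=2, shard_size=4)
print(groups["shard"])      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(groups["replicate"])  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(resolve_reshard_after_forward(None, is_root=True, shard_size=8))  # False
print(resolve_reshard_after_forward(4, is_root=False, shard_size=8))    # 4
```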
|
|
459
|
+
|
|
460
|
+
Reshards the module’s parameters, freeing the unsharded parameters if they are allocated and registering the sharded parameters to the module. This method is not recursive.
|
|
461
|
+
|
|
462
|
+
hook (Callable[[torch.Tensor], None]) – User-defined all-reduce hook with expected signature hook(reduce_output: torch.Tensor) -> None where reduce_output is the reduce-scatter output if only using FSDP or the all-reduce output if using native HSDP.
|
|
463
|
+
|
|
464
|
+
stream (Optional[torch.cuda.Stream]) – Stream to run the all-reduce hook in. This should only be set if not using native HSDP. If using native HSDP, the hook will run in the internally defined all-reduce stream used by the native HSDP all-reduce.
|
|
465
|
+
|
|
466
|
+
Sets whether the temporary staging buffers used to send and receive data over collective communications should be allocated using the custom optimized allocator provided by the ProcessGroup itself (if any). This might allow the ProcessGroup to be more efficient. For example, when using NCCL, this enables it to leverage zero-copy transfers over SHARP (for NVLink and/or InfiniBand).
|
|
467
|
+
|
|
468
|
+
This cannot be used together with set_custom_all_gather() or set_custom_reduce_scatter() as those APIs allow for finer-grained control over each communication, and this method cannot determine their staging buffer allocation strategy.
|
|
469
|
+
|
|
470
|
+
enable (bool) – Whether to turn on ProcessGroup allocation.
|
|
471
|
+
|
|
472
|
+
Overrides the default all_gather communication behavior, to have better control over the communication and memory usage. See Comm and ReduceScatter for details.
|
|
473
|
+
|
|
474
|
+
comm (AllGather) – Custom all-gather communication.
|
|
475
|
+
|
|
476
|
+
Overrides the default reduce_scatter communication behavior, to have better control over the communication and memory usage. See Comm and ReduceScatter for details.
|
|
477
|
+
|
|
478
|
+
comm (ReduceScatter) – Custom reduce_scatter communication.
|
|
479
|
+
|
|
480
|
+
Sets whether to require the low-level collective communication primitives to exclusively use “sum”-type reductions, even if it comes at the cost of separate additional pre- or post-scaling operations. This is needed for example because NCCL currently supports zero-copy transfers only for this kind of collectives.
|
|
481
|
+
|
|
482
|
+
NB: for MTIA devices, this is always implicitly enabled.
|
|
483
|
+
|
|
484
|
+
NB: if set_all_reduce_hook is used under FSDP setup, the caller needs to ensure the custom all-reduce across FSDP units follow this strategy as well, as FSDP can no longer automatically handle that.
|
|
485
|
+
|
|
486
|
+
enable (bool) – Whether to only ever use ReduceOp.SUM for comms.
|
|
487
|
+
|
|
488
|
+
Sets a custom divide factor for the gradient reduction. This might use a custom reduce op using NCCL’s PreMulSum, which allows multiplying by the factor before reduction.
|
|
489
|
+
|
|
490
|
+
factor (float) – Custom divide factor.
|
|
491
|
+
|
|
492
|
+
Sets whether the next backward is the last one. On the last backward, FSDP waits on pending gradient reduction and clears internal data structures for backward prefetching. This can be useful for microbatching.

Sets the FSDP modules for which this FSDP module should explicitly prefetch all-gathers in backward. This overrides the default backward prefetching implementation, which prefetches the next FSDP module based on the reverse post-forward order.

Passing a singleton list containing the previous FSDP module gives the same all-gather overlap behavior as the default. Passing a list of length at least two is required for more aggressive overlap and will use more reserved memory.

modules (List[FSDPModule]) – FSDP modules to prefetch.

Sets the FSDP modules for which this FSDP module should explicitly prefetch all-gathers in forward. The prefetching runs after this module’s all-gather copy-out.

Passing a singleton list containing the next FSDP module gives the same all-gather overlap behavior as the default, except that the prefetched all-gather is issued earlier from the CPU. Passing a list of length at least two is required for more aggressive overlap and will use more reserved memory.

modules (List[FSDPModule]) – FSDP modules to prefetch.

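The two prefetch setters above are typically called once, after each transformer block has been wrapped with fully_shard(). The following is a minimal sketch with these assumptions:

- layers is a Python list of blocks that are already FSDPModules.
- The blocks run in list order during forward.
- set_explicit_prefetching is a hypothetical helper.

A prefetch depth of 2 trades more reserved memory for more overlap, as described above.

```python
def set_explicit_prefetching(layers, num_to_prefetch: int = 2) -> None:
    """Configure explicit all-gather prefetching for a list of FSDP modules.

    Assumes `layers` is a list of modules already wrapped with fully_shard()
    that execute in list order in forward (and in reverse order in backward).
    """
    layers = list(layers)
    for i, layer in enumerate(layers):
        # Forward visits layers i+1, i+2, ... next, so prefetch those.
        next_layers = layers[i + 1 : i + 1 + num_to_prefetch]
        if next_layers:
            layer.set_modules_to_forward_prefetch(next_layers)
        # Backward visits layers i-1, i-2, ... next (reverse order).
        prev_layers = layers[max(0, i - num_to_prefetch) : i][::-1]
        if prev_layers:
            layer.set_modules_to_backward_prefetch(prev_layers)
```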
Sets a post-optimizer-step event for the root FSDP module to have the all-gather streams wait on.

By default, the root FSDP module makes the all-gather streams wait on the current stream to ensure that the optimizer step has finished before all-gathering. However, this may introduce false dependencies if there is unrelated computation after the optimizer step. This API allows the user to provide their own event to wait on. After the root waits on the event, the event is discarded, so this API should be called with a new event each iteration.

event (torch.Event) – Event recorded after the optimizer step for the all-gather streams to wait on.

Deprecated: use set_gradient_divide_factor() instead.

Sets whether the module should all-reduce gradients. This can be used to implement gradient accumulation with only reduce-scatter but not all-reduce for HSDP.

Sets whether the module should sync gradients. This can be used to implement gradient accumulation without communication. For HSDP, this controls both reduce-scatter and all-reduce together. This is the equivalent of no_sync in FSDP1.

requires_gradient_sync (bool) – Whether to reduce gradients for the module’s parameters.

recurse (bool) – Whether to set for all FSDP submodules or just the passed-in module.

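The flags above are typically toggled per micro-batch. Here is a minimal sketch of gradient accumulation without communication. It assumes two things:

- model has already been wrapped with fully_shard(), so it exposes set_requires_gradient_sync().
- micro_batches is a list of (inputs, targets) pairs.

accumulate_and_step is a hypothetical helper. Gradients are reduced only in the final micro-batch's backward.

```python
import torch
import torch.nn as nn


def accumulate_and_step(model: nn.Module, optimizer: torch.optim.Optimizer,
                        micro_batches, loss_fn) -> float:
    """Run one optimizer step over several micro-batches.

    Assumes `model` is an FSDPModule; gradient sync is disabled for all
    micro-batches except the last, so only the last backward communicates.
    """
    num_micro = len(micro_batches)
    total = 0.0
    for i, (inputs, targets) in enumerate(micro_batches):
        is_last = i == num_micro - 1
        # recurse=True (the default) applies the flag to all FSDP submodules.
        model.set_requires_gradient_sync(is_last)
        # Scale so the accumulated gradient matches the mean over micro-batches.
        loss = loss_fn(model(inputs), targets) / num_micro
        loss.backward()
        total += loss.item()
    optimizer.step()
    optimizer.zero_grad()
    return total
```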
Sets whether the module should reshard parameters after backward. This can be used during gradient accumulation to trade higher memory for reduced communication, since the unsharded parameters do not need to be re-all-gathered before the next forward.

reshard_after_backward (bool) – Whether to reshard parameters after backward.

recurse (bool) – Whether to set for all FSDP submodules or just the passed-in module.

Sets whether the module should reshard parameters after forward. This can be used to change the reshard_after_forward FSDP arg at runtime. For example, this can be used to set the FSDP root module’s value to True (since it is otherwise specially set to False), or to set an FSDP module’s value to False for running evals and back to True for training.

reshard_after_forward (bool) – Whether to reshard parameters after forward.

recurse (bool) – Whether to set for all FSDP submodules or just the passed-in module.

Sets whether the FSDP module’s parameters need to be unsharded in backward. This can be used in expert cases when the user knows that none of the parameters in this FSDP module’s parameter group are needed for backward computation (e.g. embedding).

Unshards the module’s parameters by allocating memory and all-gathering the parameters. This method is not recursive. The unshard follows the MixedPrecisionPolicy, so it will all-gather following param_dtype if set.

async_op (bool) – If True, returns an UnshardHandle that has a wait() method to wait on the unshard op. If False, returns None and waits on the handle inside this function.

Optional[UnshardHandle]

If async_op=True, then FSDP will wait on the pending unshard in the module’s pre-forward for the user. The user only needs to call wait() explicitly if the wait should happen before pre-forward.

A handle to wait on an FSDPModule.unshard() op.

Waits on the unshard op. This ensures that the current stream can use the unsharded parameters, which are now registered to the module.

Registers a method on module to be considered a forward method for FSDP.

FSDP all-gathers parameters pre-forward and optionally frees parameters post-forward (depending on reshard_after_forward). FSDP only knows to do this for nn.Module.forward() by default. This function patches a user-specified method to run the pre/post-forward hooks before/after the method, respectively. If module is not an FSDPModule, then this is a no-op.

module (nn.Module) – Module to register the forward method on.

method_name (str) – Name of the forward method.

This configures FSDP’s mixed precision. Unlike autocast, this applies mixed precision at the module level, not op level, which means low-precision activations are saved for backward and high-to-low-precision casts are incurred only at module boundaries.

FSDP works well with module-level mixed precision since it keeps the high-precision sharded parameters in memory anyway. In other words, FSDP does not require any extra memory to keep a high-precision copy of the parameters for the optimizer step.

param_dtype (Optional[torch.dtype]) – This specifies the dtype for the unsharded parameter and hence the dtype for forward/backward computation and the parameter all-gather. If this is None, then the unsharded parameter uses the original dtype. The optimizer step uses the sharded parameter in the original dtype. (Default: None)

reduce_dtype (Optional[torch.dtype]) – This specifies the dtype for gradient reduction (i.e. reduce-scatter or all-reduce). If this is None but param_dtype is not None, then the reduction uses the compute dtype. This can be used to run gradient reduction in full precision while using low precision for compute. If gradient reduction is also disabled via set_requires_gradient_sync(), then FSDP will accumulate gradients using reduce_dtype. (Default: None)

output_dtype (Optional[torch.dtype]) – This specifies the dtype for casting floating-point forward outputs. This can be used to help implement cases where different modules have different mixed precision policies. (Default: None)

cast_forward_inputs (bool) – This specifies whether FSDP should cast the forward’s floating-point input tensors to param_dtype or not.

This base class represents the policy of no offloading and is only used as the default value for the offload_policy arg.

This offload policy offloads parameters, gradients, and optimizer states to CPU. Sharded parameters are copied host-to-device before all-gather. The all-gathered parameters are freed according to reshard_after_forward. Sharded gradients are copied device-to-host in backward, and the optimizer step runs on CPU with CPU optimizer states.

pin_memory (bool) – Whether to pin sharded parameter and gradient memory. Pinning memory allows more efficient H2D/D2H copies and lets the copies overlap with compute. However, the pinned memory cannot be used by other processes. Set this to False if you have insufficient CPU memory. (Default: True)

---

## Distributed communication package - torch.distributed

**URL:** https://pytorch.org/docs/stable/distributed.html

**Contents:**
- Distributed communication package - torch.distributed
- Backends
- Backends that come with PyTorch
- Which backend to use?
- Common environment variables
- Choosing the network interface to use
- Other NCCL environment variables
- Basics
- Initialization
- TCP initialization

Created On: Jul 12, 2017 | Last Updated On: Sep 04, 2025

Please refer to PyTorch Distributed Overview for a brief introduction to all features related to distributed training.

torch.distributed supports four built-in backends, each with different capabilities. The table below shows which functions are available for use with a CPU or GPU for each backend. For NCCL, GPU refers to CUDA GPU, while for XCCL it refers to XPU GPU.

MPI supports CUDA only if the implementation used to build PyTorch supports it.

The PyTorch distributed package supports Linux (stable), macOS (stable), and Windows (prototype). By default on Linux, the Gloo and NCCL backends are built and included in PyTorch distributed (NCCL only when building with CUDA). MPI is an optional backend that can only be included if you build PyTorch from source (e.g. building PyTorch on a host that has MPI installed).

As of PyTorch v1.8, Windows supports all collective communication backends except NCCL. If the init_method argument of init_process_group() points to a file, it must adhere to the following schema:

- Local file system, init_method="file:///d:/tmp/some_file"
- Shared file system, init_method="file://////{machine_name}/{share_folder_name}/some_file"

As on Linux, you can enable TCPStore by setting the environment variables MASTER_ADDR and MASTER_PORT.

In the past, we were often asked: “Which backend should I use?”

- Use the NCCL backend for distributed training with CUDA GPU.
- Use the XCCL backend for distributed training with XPU GPU.
- Use the Gloo backend for distributed training with CPU.

**GPU hosts with InfiniBand interconnect**

Use NCCL, since it’s the only backend that currently supports InfiniBand and GPUDirect.

**GPU hosts with Ethernet interconnect**

Use NCCL, since it currently provides the best distributed GPU training performance, especially for multiprocess single-node or multi-node distributed training. If you encounter any problem with NCCL, use Gloo as the fallback option. (Note that Gloo currently runs slower than NCCL for GPUs.)

**CPU hosts with InfiniBand interconnect**

If your InfiniBand network has IP over IB enabled, use Gloo; otherwise, use MPI instead. We are planning on adding InfiniBand support for Gloo in upcoming releases.

**CPU hosts with Ethernet interconnect**

Use Gloo, unless you have specific reasons to use MPI.

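The GPU/CPU rule of thumb above can be expressed as a small helper. This is a sketch with several assumptions:

- pick_backend is hypothetical.
- It only distinguishes CUDA from CPU. The XPU/XCCL case and the InfiniBand/MPI cases are left out.
- It relies on the torch.distributed is_nccl_available() check.

```python
import torch
import torch.distributed as dist


def pick_backend() -> str:
    """NCCL for CUDA GPUs when it is built in; Gloo otherwise (CPU)."""
    if torch.cuda.is_available() and dist.is_nccl_available():
        return "nccl"
    return "gloo"


backend = pick_backend()
print(f"using backend: {backend}")
```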
By default, both the NCCL and Gloo backends will try to find the right network interface to use. If the automatically detected interface is not correct, you can override it using the following environment variables (applicable to the respective backend):

- NCCL_SOCKET_IFNAME, for example export NCCL_SOCKET_IFNAME=eth0
- GLOO_SOCKET_IFNAME, for example export GLOO_SOCKET_IFNAME=eth0

If you’re using the Gloo backend, you can specify multiple interfaces by separating them with commas, like this: export GLOO_SOCKET_IFNAME=eth0,eth1,eth2,eth3. The backend will dispatch operations in a round-robin fashion across these interfaces. All processes must specify the same number of interfaces in this variable.

Debugging - in case of NCCL failure, you can set NCCL_DEBUG=INFO to print an explicit warning message as well as basic NCCL initialization information.

You may also use NCCL_DEBUG_SUBSYS to get more details about a specific aspect of NCCL. For example, NCCL_DEBUG_SUBSYS=COLL would print logs of collective calls, which may be helpful when debugging hangs, especially those caused by collective type or message size mismatch. In case of topology detection failure, it would be helpful to set NCCL_DEBUG_SUBSYS=GRAPH to inspect the detailed detection result and save it as a reference if further help from the NCCL team is needed.

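NCCL reads these variables when it initializes. They therefore need to be in the environment before the first NCCL communicator is created. You can export them in the launcher or set them at the very top of the training script. Here is a minimal sketch of the latter approach. enable_nccl_debugging is a hypothetical helper, and setdefault leaves any value already supplied by the launcher untouched.

```python
import os


def enable_nccl_debugging(subsys: str = "COLL") -> None:
    """Turn on NCCL logging unless the environment already configures it.

    Call before torch.distributed.init_process_group() (or any collective).
    COLL logs collective calls (useful for hangs); GRAPH logs topology detection.
    """
    os.environ.setdefault("NCCL_DEBUG", "INFO")
    os.environ.setdefault("NCCL_DEBUG_SUBSYS", subsys)


enable_nccl_debugging()
print(os.environ["NCCL_DEBUG"], os.environ["NCCL_DEBUG_SUBSYS"])
```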
Performance tuning - NCCL performs automatic tuning based on its topology detection to save users’ tuning effort. On some socket-based systems, users may still try tuning NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD to increase socket network bandwidth. These two environment variables have been pre-tuned by NCCL for some cloud providers, such as AWS or GCP.

For a full list of NCCL environment variables, please refer to NVIDIA NCCL’s official documentation.

You can tune NCCL communicators even further using torch.distributed.ProcessGroupNCCL.NCCLConfig and torch.distributed.ProcessGroupNCCL.Options. Learn more about them using help (e.g. help(torch.distributed.ProcessGroupNCCL.NCCLConfig)) in the interpreter.

The torch.distributed package provides PyTorch support and communication primitives for multiprocess parallelism across several computation nodes running on one or more machines. The class torch.nn.parallel.DistributedDataParallel() builds on this functionality to provide synchronous distributed training as a wrapper around any PyTorch model. This differs from the kinds of parallelism provided by Multiprocessing package - torch.multiprocessing and torch.nn.DataParallel() in that it supports multiple network-connected machines and in that the user must explicitly launch a separate copy of the main training script for each process.

In the single-machine synchronous case, torch.distributed or the torch.nn.parallel.DistributedDataParallel() wrapper may still have advantages over other approaches to data-parallelism, including torch.nn.DataParallel():

Each process maintains its own optimizer and performs a complete optimization step with each iteration. While this may appear redundant, since the gradients have already been gathered together and averaged across processes and are thus the same for every process, this means that no parameter broadcast step is needed, reducing time spent transferring tensors between nodes.

Each process contains an independent Python interpreter, eliminating the extra interpreter overhead and “GIL-thrashing” that comes from driving several execution threads, model replicas, or GPUs from a single Python process. This is especially important for models that make heavy use of the Python runtime, including models with recurrent layers or many small components.

The package needs to be initialized using the torch.distributed.init_process_group() or torch.distributed.device_mesh.init_device_mesh() function before calling any other methods. Both block until all processes have joined.

Initialization is not thread-safe. Process group creation should be performed from a single thread, to prevent inconsistent ‘UUID’ assignment across ranks, and to prevent races during initialization that can lead to hangs.

Return True if the distributed package is available.

Otherwise, torch.distributed does not expose any other APIs. Currently, torch.distributed is available on Linux, macOS and Windows. Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source. Currently, the default value is USE_DISTRIBUTED=1 for Linux and Windows, USE_DISTRIBUTED=0 for macOS.

Initialize the default distributed process group.

This will also initialize the distributed package.

There are 2 main ways to initialize a process group:

1. Specify store, rank, and world_size explicitly.
2. Specify init_method (a URL string) which indicates where/how to discover peers. Optionally specify rank and world_size, or encode all required parameters in the URL and omit them.

If neither is specified, init_method is assumed to be “env://”.

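Here is a minimal single-process sketch of the first style, which passes store, rank and world_size explicitly. It assumes the Gloo backend on CPU and a localhost TCPStore. free_port is a hypothetical helper that asks the OS for an unused port. In a real job, every rank would create a TCPStore pointing at the same host and port, and only rank 0 would pass is_master=True.

```python
import socket
import torch.distributed as dist


def free_port() -> int:
    """Ask the OS for a currently unused TCP port on localhost."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


# This process hosts the store (is_master=True) and is the only rank.
store = dist.TCPStore("127.0.0.1", free_port(), world_size=1, is_master=True)
dist.init_process_group("gloo", store=store, rank=0, world_size=1)
try:
    rank, world_size = dist.get_rank(), dist.get_world_size()
finally:
    dist.destroy_process_group()
print(rank, world_size)
```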
backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values include mpi, gloo, nccl, ucc, xccl or one that is registered by a third-party plugin. Since 2.6, if backend is not provided, c10d will use a backend registered for the device type indicated by the device_id kwarg (if provided). The known default registrations today are: nccl for cuda, gloo for cpu, xccl for xpu. If neither backend nor device_id is provided, c10d will detect the accelerator on the run-time machine and use a backend registered for that detected accelerator (or cpu). This field can be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If using multiple processes per machine with nccl backend, each process must have exclusive access to every GPU it uses, as sharing GPUs between processes can result in deadlock or NCCL invalid usage. ucc backend is experimental. Default backend for the device can be queried with get_default_backend_for_device().

init_method (str, optional) – URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified. Mutually exclusive with store.

world_size (int, optional) – Number of processes participating in the job. Required if store is specified.

rank (int, optional) – Rank of the current process (it should be a number between 0 and world_size-1). Required if store is specified.

store (Store, optional) – Key/value store accessible to all workers, used to exchange connection/address information. Mutually exclusive with init_method.

timeout (timedelta, optional) – Timeout for operations executed against the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends. This is the duration after which collectives will be aborted asynchronously and the process will crash. This is done since CUDA execution is async and it is no longer safe to continue executing user code since failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout.

group_name (str, optional, deprecated) – Group name. This argument is ignored.

pg_options (ProcessGroupOptions, optional) – Process group options specifying what additional options need to be passed in during the construction of specific process groups. As of now, the only options supported are ProcessGroupNCCL.Options for the nccl backend; is_high_priority_stream can be specified so that the nccl backend can pick up high-priority CUDA streams when there are compute kernels waiting. For other available options to configure NCCL, see https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t

device_id (torch.device | int, optional) – A single, specific device this process will work on, allowing for backend-specific optimizations. Currently this has two effects, only under NCCL: the communicator is formed immediately (calling ncclCommInit* immediately rather than the normal lazy call), and sub-groups will use ncclCommSplit when possible to avoid unnecessary overhead of group creation. If you want to catch NCCL initialization errors early, you can also use this field. If an int is provided, the API assumes that the accelerator type at compile time will be used.

To enable backend == Backend.MPI, PyTorch needs to be built from source on a system that supports MPI.

Support for multiple backends is experimental. Currently when no backend is specified, both gloo and nccl backends will be created. The gloo backend will be used for collectives with CPU tensors and the nccl backend will be used for collectives with CUDA tensors. A custom backend can be specified by passing in a string with format “<device_type>:<backend_name>,<device_type>:<backend_name>”, e.g. “cpu:gloo,cuda:custom_backend”.

Initializes a DeviceMesh based on device_type, mesh_shape, and mesh_dim_names parameters.

This creates a DeviceMesh with an n-dimensional array layout, where n is the length of mesh_shape. If mesh_dim_names is provided, each dimension is labeled as mesh_dim_names[i].

init_device_mesh follows the SPMD programming model, meaning the same PyTorch Python program runs on all processes/ranks in the cluster. Ensure mesh_shape (the dimensions of the nD array describing device layout) is identical across all ranks. An inconsistent mesh_shape may lead to hanging.

If no process group is found, init_device_mesh will initialize the distributed process group(s) required for distributed communication behind the scenes.

device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”, “xpu”. Passing in a device type with a GPU index, such as “cuda:0”, is not allowed.

mesh_shape (Tuple[int]) – A tuple defining the dimensions of the multi-dimensional array describing the layout of devices.

mesh_dim_names (Tuple[str], optional) – A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of mesh_shape. Each string in mesh_dim_names must be unique.

backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional) – Overrides for some or all of the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name of the backend and its options, or just one of these two components (in which case the other will be set to its default value).

A DeviceMesh object representing the device layout.

Check if the default process group has been initialized.

Check if the MPI backend is available.

Check if the NCCL backend is available.

Check if the Gloo backend is available.

Check if the XCCL backend is available.

Check whether this process was launched with torch.distributed.elastic (aka torchelastic).

The existence of TORCHELASTIC_RUN_ID environment variable is used as a proxy to determine whether the current process was launched with torchelastic. This is a reasonable proxy since TORCHELASTIC_RUN_ID maps to the rendezvous id which is always a non-null value indicating the job id for peer discovery purposes.

Return the default backend for the given device.

device (Union[str, torch.device]) – The device to get the default backend for.

The default backend for the given device as a lowercase string.

Currently three initialization methods are supported:

There are two ways to initialize using TCP, both requiring a network address reachable from all processes and a desired world_size. The first way requires specifying an address that belongs to the rank 0 process. This initialization method requires that all processes have manually specified ranks.

Note that multicast address is not supported anymore in the latest distributed package. group_name is deprecated as well.

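Here is a minimal single-process sketch of TCP initialization. It assumes the Gloo backend and uses localhost as the rank 0 address. free_port is a hypothetical helper. With more processes, every rank would pass the same tcp:// URL (rank 0's reachable address) together with its own manually specified rank.

```python
import socket
import torch.distributed as dist


def free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


init_method = f"tcp://127.0.0.1:{free_port()}"  # address of the rank 0 process
dist.init_process_group("gloo", init_method=init_method, rank=0, world_size=1)
try:
    info = (dist.get_rank(), dist.get_world_size(), dist.get_backend())
finally:
    dist.destroy_process_group()
print(info)
```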
Another initialization method makes use of a file system that is shared and visible from all machines in a group, along with a desired world_size. The URL should start with file:// and contain a path to a non-existent file (in an existing directory) on a shared file system. File-system initialization will automatically create that file if it doesn’t exist, but will not delete the file. Therefore, it is your responsibility to make sure that the file is cleaned up before the next init_process_group() call on the same file path/name.

Note that automatic rank assignment is not supported anymore in the latest distributed package and group_name is deprecated as well.

This method assumes that the file system supports locking using fcntl - most local systems and NFS support it.

This method will always create the file and try its best to clean up and remove the file at the end of the program. In other words, each initialization with the file init method needs a brand new empty file in order to succeed. Reusing a file left over from a previous initialization (one that happened not to get cleaned up) is unexpected behavior and can often cause deadlocks and failures. Therefore, even though this method will try its best to clean up the file, if the auto-delete happens to be unsuccessful, it is your responsibility to ensure that the file is removed at the end of training so that it is not reused next time. This is especially important if you plan to call init_process_group() multiple times on the same file name. The rule of thumb is to make sure that the file is non-existent or empty every time init_process_group() is called.

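Here is a minimal single-process sketch of file-system initialization. It assumes a POSIX path and the Gloo backend. It also uses a fresh temporary directory, which guarantees that the rendezvous file does not exist beforehand, following the rule of thumb above. With several machines, the path would instead point at a shared file system visible to all of them.

```python
import os
import tempfile
import torch
import torch.distributed as dist

with tempfile.TemporaryDirectory() as tmpdir:
    init_file = os.path.join(tmpdir, "dist_init")  # must not exist yet
    dist.init_process_group(
        "gloo", init_method=f"file://{init_file}", rank=0, world_size=1
    )
    try:
        t = torch.ones(3)
        dist.all_reduce(t)  # default op is SUM; with one rank, t is unchanged
        result = t.tolist()
    finally:
        dist.destroy_process_group()
# Leaving the with-block deletes the directory, so the file cannot be reused.
print(result)
```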
This method will read the configuration from environment variables, allowing one to fully customize how the information is obtained. The variables to be set are:

- MASTER_PORT - required; has to be a free port on the machine with rank 0
- MASTER_ADDR - required (except for rank 0); address of the rank 0 node
- WORLD_SIZE - required; can be set either here, or in a call to the init function
- RANK - required; can be set either here, or in a call to the init function

The machine with rank 0 will be used to set up all connections.

This is the default method, meaning that init_method does not have to be specified (or can be env://).

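Here is a minimal single-process sketch of env:// initialization. The values below are illustrative defaults for a local run. setdefault keeps whatever a launcher such as torchrun has already exported. free_port is a hypothetical helper.

```python
import os
import socket
import torch.distributed as dist


def free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", str(free_port()))
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

dist.init_process_group("gloo")  # no init_method: defaults to "env://"
try:
    rank, world_size = dist.get_rank(), dist.get_world_size()
finally:
    dist.destroy_process_group()
print(rank, world_size)
```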
TORCH_GLOO_LAZY_INIT - establishes connections on demand rather than using a full mesh, which can greatly improve initialization time for non-all2all operations.

Once torch.distributed.init_process_group() has been run, the following functions can be used. To check whether the process group has already been initialized, use torch.distributed.is_initialized().

An enum-like class for backends.

Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends.

The values of this class are lowercase strings, e.g., "gloo". They can be accessed as attributes, e.g., Backend.NCCL.

This class can be directly called to parse the string, e.g., Backend(backend_str) will check if backend_str is valid, and return the parsed lowercase string if so. It also accepts uppercase strings, e.g., Backend("GLOO") returns "gloo".

The entry Backend.UNDEFINED is present but only used as initial value of some fields. Users should neither use it directly nor assume its existence.

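Backend values are plain lowercase strings, so parsing and comparing them is simple. The following short sketch shows the parsing behavior described above. It does not need a process group.

```python
from torch.distributed import Backend

parsed = Backend("GLOO")       # uppercase input is accepted and lowercased
nccl_name = Backend.NCCL       # attribute access yields the lowercase string
print(parsed, nccl_name)
```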
Register a new backend with the given name and instantiating function.
|
|
781
|
+
|
|
782
|
+
This class method is used by 3rd party ProcessGroup extension to register new backends.
|
|
783
|
+
|
|
784
|
+
name (str) – Backend name of the ProcessGroup extension. It should match the one in init_process_group().
|
|
785
|
+
|
|
786
|
+
func (function) – Function handler that instantiates the backend. The function should be implemented in the backend extension and takes four arguments, including store, rank, world_size, and timeout.
|
|
787
|
+
|
|
788
|
+
extended_api (bool, optional) – Whether the backend supports extended argument structure. Default: False. If set to True, the backend will get an instance of c10d::DistributedBackendOptions, and a process group options object as defined by the backend implementation.
|
|
789
|
+
|
|
790
|
+
device (str or list of str, optional) – device type this backend supports, e.g. “cpu”, “cuda”, etc. If None, assuming both “cpu” and “cuda”
|
|
791
|
+
|
|
792
|
+
This support of 3rd party backend is experimental and subject to change.

Return the backend of the given process group.

group (ProcessGroup, optional) – The process group to work on. The default is the general main process group. If another specific group is specified, the calling process must be part of group.

Returns: the backend of the given process group as a lowercase string.

Return the rank of the current process in the provided group, or in the default group otherwise.

Rank is a unique identifier assigned to each process within a distributed process group. Ranks are always consecutive integers ranging from 0 to world_size - 1.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

Returns: the rank of the process in the group, or -1 if the process is not part of the group.

Return the number of processes in the current process group.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

Returns: the world size of the process group, or -1 if the process is not part of the group.

It is important to clean up resources on exit by calling destroy_process_group().

The simplest pattern to follow is to destroy every process group and backend by calling destroy_process_group() with the default value of None for the group argument, at a point in the training script where communications are no longer needed, usually near the end of main(). The call should be made once per trainer process, not at the outer process-launcher level.

If destroy_process_group() is not called by all ranks in a process group within the timeout duration, hangs on exit are possible, especially when the application has multiple process groups (e.g. for N-D parallelism). This is because the destructor for ProcessGroupNCCL calls ncclCommAbort, which must be called collectively, but the order in which Python’s GC calls ProcessGroupNCCL’s destructors is not deterministic. Calling destroy_process_group() helps by ensuring ncclCommAbort is called in a consistent order across ranks, and avoids calling ncclCommAbort during ProcessGroupNCCL’s destructor.

destroy_process_group can also be used to destroy individual process groups. One use case could be fault-tolerant training, where a process group may be destroyed and then a new one initialized during runtime. In this case, it’s critical to synchronize the trainer processes using some means other than torch.distributed primitives _after_ calling destroy and before subsequently initializing. This behavior is currently unsupported/untested, due to the difficulty of achieving this synchronization, and is considered a known issue. Please file a GitHub issue or RFC if this is a use case that’s blocking you.

By default, collectives operate on the default group (also called the world) and require all processes to enter the distributed function call. However, some workloads can benefit from more fine-grained communication. This is where distributed groups come into play. The new_group() function can be used to create new groups with arbitrary subsets of all processes. It returns an opaque group handle that can be given as a group argument to all collectives (collectives are distributed functions that exchange information in certain well-known programming patterns).

Create a new distributed group.

This function requires that all processes in the main group (i.e. all processes that are part of the distributed job) enter this function, even if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes.

Safe concurrent usage: When using multiple process groups with the NCCL backend, the user must ensure a globally consistent execution order of collectives across ranks.

If multiple threads within a process issue collectives, explicit synchronization is necessary to ensure consistent ordering.

When using async variants of torch.distributed communication APIs, a work object is returned and the communication kernel is enqueued on a separate CUDA stream, allowing overlap of communication and computation. Once one or more async ops have been issued on one process group, they must be synchronized with other CUDA streams by calling work.wait() before using another process group.

See [Using multiple NCCL communicators concurrently](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using-multiple-nccl-communicators-concurrently) for more details.

ranks (list[int]) – List of ranks of group members. If None, will be set to all ranks. Default is None.

timeout (timedelta, optional) – see init_process_group for details and default value.

backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values are gloo and nccl. By default, uses the same backend as the global group. This field should be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If None is passed in, the backend corresponding to the default process group will be used. Default is None.

pg_options (ProcessGroupOptions, optional) – process group options specifying what additional options need to be passed in during the construction of specific process groups. For example, for the nccl backend, is_high_priority_stream can be specified so that the process group can pick up high-priority CUDA streams. For other available options to configure NCCL, see https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t

use_local_synchronization (bool, optional) – perform a group-local barrier at the end of the process group creation. This is different in that non-member ranks don’t need to call into the API and don’t join the barrier.

group_desc (str, optional) – a string to describe the process group.

device_id (torch.device, optional) – a single, specific device to “bind” this process to. The new_group call will try to initialize a communication backend immediately for the device if this field is given.

Returns: a handle of a distributed group that can be given to collective calls, or GroupMember.NON_GROUP_MEMBER if the rank is not part of ranks.

N.B. use_local_synchronization doesn’t work with MPI.

N.B. While use_local_synchronization=True can be significantly faster with larger clusters and small process groups, care must be taken since it changes cluster behavior, as non-member ranks don’t join the group barrier().

N.B. use_local_synchronization=True can lead to deadlocks when each rank creates multiple overlapping process groups. To avoid that, make sure all ranks follow the same global creation order.
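
This sketch shows the rule that every rank must enter new_group(), and what non-members get back. It assumes PyTorch with Gloo on a POSIX host where the "fork" start method is available. `run()` forks two local CPU processes that rendezvous over TCP on localhost. Both processes call new_group(ranks=[1]). Rank 0 is not a member, so it should see -1 from get_rank() and get_world_size(). `_worker` and `run` are illustrative helpers, not torch.distributed APIs.

```python
import multiprocessing as mp
import socket

import torch.distributed as dist


def _worker(rank: int, world_size: int, port: int, queue) -> None:
    # Each process joins a CPU-only Gloo group over TCP on localhost.
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=rank, world_size=world_size)
    try:
        # Every rank must call new_group, including rank 0, which is not a member.
        group = dist.new_group(ranks=[1])
        # Non-members receive GroupMember.NON_GROUP_MEMBER, so both queries return -1.
        result = (dist.get_rank(group), dist.get_world_size(group))
        dist.barrier()  # keep every rank alive until all communication is done
        queue.put((rank, result))
    finally:
        dist.destroy_process_group()


def run(world_size: int = 2) -> dict:
    with socket.socket() as s:  # pick a free port for the rendezvous
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    ctx = mp.get_context("fork")  # assumption: POSIX host where fork is available
    queue = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(r, world_size, port, queue), daemon=True)
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=60) for _ in procs)  # {rank: result}
    for p in procs:
        p.join()
    return results
```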

Translate a global rank into a group rank.

global_rank must be part of group; otherwise this raises RuntimeError.

group (ProcessGroup) – ProcessGroup to find the relative rank.

global_rank (int) – Global rank to query.

Returns: group rank of global_rank relative to group.

N.B. Calling this function on the default process group returns the identity.

Translate a group rank into a global rank.

group_rank must be part of group; otherwise this raises RuntimeError.

group (ProcessGroup) – ProcessGroup to find the global rank from.

group_rank (int) – Group rank to query.

Returns: global rank of group_rank relative to group.

N.B. Calling this function on the default process group returns the identity.

Get all ranks associated with group.

group (Optional[ProcessGroup]) – ProcessGroup to get all ranks from. If None, the default process group will be used.

Returns: list of global ranks ordered by group rank.
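
Here is a minimal sketch of the three rank-translation helpers. It assumes PyTorch with Gloo on a POSIX host where "fork" is available. `_worker` and `run` are hypothetical launcher helpers. Three local CPU processes create a group from global ranks 0 and 2. The members then map between global and group ranks. Rank 1 is not in the group, so it skips the translation calls.

```python
import multiprocessing as mp
import socket

import torch.distributed as dist


def _worker(rank: int, world_size: int, port: int, queue) -> None:
    # Each process joins a CPU-only Gloo group over TCP on localhost.
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=rank, world_size=world_size)
    try:
        group = dist.new_group(ranks=[0, 2])  # all ranks call it, in the same order
        if dist.get_rank(group) == -1:
            result = None  # rank 1 is not a member of the group
        else:
            result = (
                dist.get_process_group_ranks(group),  # global ranks by group rank
                dist.get_group_rank(group, 2),        # global rank 2 -> group rank
                dist.get_global_rank(group, 1),       # group rank 1 -> global rank
            )
        dist.barrier()  # keep every rank alive until all communication is done
        queue.put((rank, result))
    finally:
        dist.destroy_process_group()


def run(world_size: int = 3) -> dict:
    with socket.socket() as s:  # pick a free port for the rendezvous
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    ctx = mp.get_context("fork")  # assumption: POSIX host where fork is available
    queue = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(r, world_size, port, queue), daemon=True)
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=60) for _ in procs)  # {rank: result}
    for p in procs:
        p.join()
    return results
```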

DeviceMesh is a higher-level abstraction that manages process groups (or NCCL communicators). It allows users to easily create inter-node and intra-node process groups without worrying about how to set up the ranks correctly for different sub process groups, and it helps manage those distributed process groups easily. The init_device_mesh() function can be used to create a new DeviceMesh, with a mesh shape describing the device topology.

DeviceMesh represents a mesh of devices, where the layout of devices can be represented as an n-dimensional array, and each value of the n-dimensional array is the global ID of a rank in the default process group.

DeviceMesh can be used to set up the N-dimensional device connections across the cluster and to manage the ProcessGroups for N-dimensional parallelisms. Communications can happen on each dimension of the DeviceMesh separately. DeviceMesh respects the device that the user has already selected (i.e. if the user calls torch.cuda.set_device before the DeviceMesh initialization), and will select/set the device for the current process if the user does not set the device beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization.

DeviceMesh can also be used as a context manager when used together with DTensor APIs.

DeviceMesh follows the SPMD programming model, which means the same PyTorch Python program runs on all processes/ranks in the cluster. Therefore, users need to make sure the mesh array (which describes the layout of devices) is identical across all ranks. An inconsistent mesh will lead to a silent hang.

device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”.

mesh (ndarray) – A multi-dimensional array or an integer tensor describing the layout of devices, where the IDs are global IDs of the default process group.

Returns: a DeviceMesh object representing the device layout.

The following program runs on each process/rank in an SPMD manner. In this example, we have 2 hosts with 4 GPUs each. A reduction over the first dimension of mesh will reduce across columns (0, 4), ..., (3, 7), and a reduction over the second dimension of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).

Constructs a DeviceMesh with device_type from an existing ProcessGroup or a list of existing ProcessGroups.

The constructed device mesh has a number of dimensions equal to the number of groups passed. For example, if a single process group is passed in, the resulting DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in, the resulting DeviceMesh is a 2D mesh.

If more than one group is passed, then the mesh and mesh_dim_names arguments are required. The order of the process groups passed in determines the topology of the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh. The mesh tensor passed in must have the same number of dimensions as the number of process groups passed in, and the order of the dimensions in the mesh tensor must match the order of the process groups passed in.

group (ProcessGroup or list[ProcessGroup]) – the existing ProcessGroup or a list of existing ProcessGroups.

device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”. Passing in a device type with a GPU index, such as “cuda:0”, is not allowed.

mesh (torch.Tensor or ArrayLike, optional) – A multi-dimensional array or an integer tensor describing the layout of devices, where the IDs are global IDs of the default process group. Default is None.

mesh_dim_names (tuple[str], optional) – A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of mesh_shape. Each string in mesh_dim_names must be unique. Default is None.

Returns: a DeviceMesh object representing the device layout.

Returns a list of ProcessGroups for all mesh dimensions.

Returns: a list of ProcessGroup objects.

Return type: list[torch.distributed.distributed_c10d.ProcessGroup]

Return the indices of this rank relative to all dimensions of the mesh. If this rank is not part of the mesh, return None.

Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh.

mesh_dim (str or int, optional) – the name or the index of the mesh dimension. Default is None.

Returns: a ProcessGroup object.

Returns the local rank of the given mesh_dim of the DeviceMesh.

mesh_dim (str or int, optional) – the name or the index of the mesh dimension. Default is None.

Returns: an integer denoting the local rank.

The following program runs on each process/rank in an SPMD manner. In this example, we have 2 hosts with 4 GPUs each. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.

Returns the current global rank.

Send a tensor synchronously.

tag is not supported with the NCCL backend.

tensor (Tensor) – Tensor to send.

dst (int) – Destination rank on global process group (regardless of group argument). Destination rank should not be the same as the rank of the current process.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

tag (int, optional) – Tag to match send with remote recv.

group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst.

Receives a tensor synchronously.

tag is not supported with the NCCL backend.

tensor (Tensor) – Tensor to fill with received data.

src (int, optional) – Source rank on global process group (regardless of group argument). Will receive from any process if unspecified.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

tag (int, optional) – Tag to match recv with remote send.

group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src.

Returns: sender rank, or -1 if not part of the group.
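
Here is a minimal blocking send/recv sketch between two local CPU processes. It assumes PyTorch with Gloo and a POSIX host where "fork" is available. Rank 0 sends a tensor. Rank 1 receives it into a pre-allocated buffer and reports the sender rank that recv returns. `_worker` and `run` are illustrative helpers, not torch.distributed APIs.

```python
import multiprocessing as mp
import socket

import torch
import torch.distributed as dist


def _worker(rank: int, world_size: int, port: int, queue) -> None:
    # Each process joins a CPU-only Gloo group over TCP on localhost.
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=rank, world_size=world_size)
    try:
        if rank == 0:
            dist.send(torch.tensor([1.0, 2.0, 3.0]), dst=1)  # blocking send
            result = None
        else:
            buf = torch.zeros(3)  # receive buffer with a matching shape
            sender = dist.recv(buf, src=0)  # returns the sender's rank
            result = (sender, buf.tolist())
        dist.barrier()  # keep every rank alive until all communication is done
        queue.put((rank, result))
    finally:
        dist.destroy_process_group()


def run(world_size: int = 2) -> dict:
    with socket.socket() as s:  # pick a free port for the rendezvous
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    ctx = mp.get_context("fork")  # assumption: POSIX host where fork is available
    queue = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(r, world_size, port, queue), daemon=True)
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=60) for _ in procs)  # {rank: result}
    for p in procs:
        p.join()
    return results
```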

isend() and irecv() return distributed request objects when used. In general, the type of this object is unspecified as they should never be created manually, but they are guaranteed to support two methods:

is_completed() - returns True if the operation has finished

wait() - will block the process until the operation is finished. is_completed() is guaranteed to return True once it returns.

Send a tensor asynchronously.

Modifying tensor before the request completes causes undefined behavior.

tag is not supported with the NCCL backend.

Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self.

tensor (Tensor) – Tensor to send.

dst (int) – Destination rank on global process group (regardless of group argument).

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

tag (int, optional) – Tag to match send with remote recv.

group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst.

Returns: a distributed request object, or None if not part of the group.

Receives a tensor asynchronously.

tag is not supported with the NCCL backend.

Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self.

tensor (Tensor) – Tensor to fill with received data.

src (int, optional) – Source rank on global process group (regardless of group argument). Will receive from any process if unspecified.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

tag (int, optional) – Tag to match recv with remote send.

group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src.

Returns: a distributed request object, or None if not part of the group.
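
Here is a minimal isend/irecv sketch in which two local CPU processes swap values. It assumes PyTorch with Gloo and a POSIX host with "fork". Both ranks post their send and receive without blocking, then call wait() on each request. Once wait() returns, is_completed() should report True. `_worker` and `run` are illustrative helpers.

```python
import multiprocessing as mp
import socket

import torch
import torch.distributed as dist


def _worker(rank: int, world_size: int, port: int, queue) -> None:
    # Each process joins a CPU-only Gloo group over TCP on localhost.
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=rank, world_size=world_size)
    try:
        peer = 1 - rank  # the two ranks exchange with each other
        send_buf = torch.tensor([float(rank)])
        recv_buf = torch.zeros(1)
        # Both calls return immediately; send_buf must not change until done.
        reqs = [dist.isend(send_buf, dst=peer), dist.irecv(recv_buf, src=peer)]
        for req in reqs:
            req.wait()
        result = (recv_buf.item(), all(req.is_completed() for req in reqs))
        dist.barrier()  # keep every rank alive until all communication is done
        queue.put((rank, result))
    finally:
        dist.destroy_process_group()


def run(world_size: int = 2) -> dict:
    with socket.socket() as s:  # pick a free port for the rendezvous
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    ctx = mp.get_context("fork")  # assumption: POSIX host where fork is available
    queue = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(r, world_size, port, queue), daemon=True)
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=60) for _ in procs)  # {rank: result}
    for p in procs:
        p.join()
    return results
```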

Sends picklable objects in object_list synchronously.

Similar to send(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be sent.

object_list (List[Any]) – List of input objects to send. Each object must be picklable. The receiver must provide a list of equal size.

dst (int) – Destination rank to send object_list to. Destination rank is based on global process group (regardless of group argument).

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

device (torch.device, optional) – If not None, the objects are serialized and converted to tensors which are moved to the device before sending. Default is None.

group_dst (int, optional) – Destination rank on group. Must specify one of dst and group_dst but not both.

use_batch (bool, optional) – If True, use batch p2p operations instead of regular send operations. This avoids initializing 2-rank communicators and uses existing entire group communicators. See batch_isend_irecv for usage and assumptions. Default is False.

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

send_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Calling send_object_list() with GPU tensors is not well supported and is inefficient, as it incurs a GPU -> CPU transfer since tensors would be pickled. Please consider using send() instead.

Receives picklable objects in object_list synchronously.

Similar to recv(), but can receive Python objects.

object_list (List[Any]) – List of objects to receive into. Must be a list of the same size as the list being sent.

src (int, optional) – Source rank from which to recv object_list. Source rank is based on global process group (regardless of group argument). Will receive from any rank if set to None. Default is None.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

device (torch.device, optional) – If not None, receives on this device. Default is None.

group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src.

use_batch (bool, optional) – If True, use batch p2p operations instead of regular recv operations. This avoids initializing 2-rank communicators and uses existing entire group communicators. See batch_isend_irecv for usage and assumptions. Default is False.

Returns: sender rank, or -1 if the rank is not part of the group. If the rank is part of the group, object_list will contain the sent objects from the src rank.

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

recv_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Calling recv_object_list() with GPU tensors is not well supported and is inefficient, as it incurs a GPU -> CPU transfer since tensors would be pickled. Please consider using recv() instead.

Send or receive a batch of tensors asynchronously and return a list of requests.

Process each of the operations in p2p_op_list and return the corresponding requests. NCCL, Gloo, and UCC backends are currently supported.

p2p_op_list (list[torch.distributed.distributed_c10d.P2POp]) – A list of point-to-point operations (the type of each operator is torch.distributed.P2POp). The order of the isend/irecv in the list matters and it needs to match the corresponding isend/irecv on the remote end.

Returns: a list of distributed request objects returned by calling the corresponding op in the op_list.

Return type: list[torch.distributed.distributed_c10d.Work]

Note that when this API is used with the NCCL PG backend, users must set the current GPU device with torch.cuda.set_device; otherwise it will lead to unexpected hang issues.

In addition, if this API is the first collective call in the group passed to dist.P2POp, all ranks of the group must participate in this API call; otherwise, the behavior is undefined. If this API call is not the first collective call in the group, batched P2P operations involving only a subset of ranks of the group are allowed.

A class to build point-to-point operations for batch_isend_irecv.

This class builds the type of P2P operation, communication buffer, peer rank, Process Group, and tag. Instances of this class will be passed to batch_isend_irecv for point-to-point communications.

op (Callable) – A function to send data to or receive data from a peer process. The type of op is either torch.distributed.isend or torch.distributed.irecv.

tensor (Tensor) – Tensor to send or receive.

peer (int, optional) – Destination or source rank.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

tag (int, optional) – Tag to match send with recv.

group_peer (int, optional) – Destination or source rank on group.
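
Here is a minimal ring exchange built from P2POp and batch_isend_irecv. It assumes PyTorch with Gloo and a POSIX host with "fork". Three local CPU processes each send their rank to the right-hand neighbour and receive from the left-hand one. Because this is the first call on the group, every rank participates. `_worker` and `run` are illustrative helpers.

```python
import multiprocessing as mp
import socket

import torch
import torch.distributed as dist


def _worker(rank: int, world_size: int, port: int, queue) -> None:
    # Each process joins a CPU-only Gloo group over TCP on localhost.
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=rank, world_size=world_size)
    try:
        right = (rank + 1) % world_size
        left = (rank - 1) % world_size
        send_buf = torch.tensor([float(rank)])
        recv_buf = torch.zeros(1)
        ops = [
            dist.P2POp(dist.isend, send_buf, right),  # send to the right neighbour
            dist.P2POp(dist.irecv, recv_buf, left),   # receive from the left one
        ]
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        result = recv_buf.item()
        dist.barrier()  # keep every rank alive until all communication is done
        queue.put((rank, result))
    finally:
        dist.destroy_process_group()


def run(world_size: int = 3) -> dict:
    with socket.socket() as s:  # pick a free port for the rendezvous
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    ctx = mp.get_context("fork")  # assumption: POSIX host where fork is available
    queue = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(r, world_size, port, queue), daemon=True)
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=60) for _ in procs)  # {rank: result}
    for p in procs:
        p.join()
    return results
```

Each rank should end up holding its left neighbour's rank: 2.0 on rank 0, 0.0 on rank 1, and 1.0 on rank 2.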

Every collective operation function supports the following two kinds of operations, depending on the setting of the async_op flag passed into the collective:

Synchronous operation - the default mode, when async_op is set to False. When the function returns, it is guaranteed that the collective operation is performed. In the case of CUDA operations, it is not guaranteed that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of synchronization when running under different streams. For details on CUDA semantics such as stream synchronization, see CUDA Semantics. See the script below for examples of the differences in these semantics for CPU and CUDA operations.

Asynchronous operation - when async_op is set to True. The collective operation function returns a distributed request object. In general, you don’t need to create it manually, and it is guaranteed to support the following methods:

is_completed() - in the case of CPU collectives, returns True if completed. In the case of CUDA operations, returns True if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the default stream without further synchronization.

wait() - in the case of CPU collectives, will block the process until the operation is completed. In the case of CUDA collectives, will block the currently active CUDA stream until the operation is completed (but will not block the CPU).

get_future() - returns a torch._C.Future object. Supported for NCCL; also supported for most operations on GLOO and MPI, except for peer-to-peer operations. Note: as we continue adopting Futures and merging APIs, the get_future() call might become redundant.

The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. It shows the explicit need to synchronize when using collective outputs on different CUDA streams:
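
This is not the CUDA stream script mentioned above. It is a CPU-only illustration of the async_op=True path, assuming PyTorch with Gloo and a POSIX host with "fork". Two local processes start an all_reduce that returns a work handle. They could overlap independent work with the communication, then call wait(), which on CPU blocks until the result is ready. `_worker` and `run` are illustrative helpers.

```python
import multiprocessing as mp
import socket

import torch
import torch.distributed as dist


def _worker(rank: int, world_size: int, port: int, queue) -> None:
    # Each process joins a CPU-only Gloo group over TCP on localhost.
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=rank, world_size=world_size)
    try:
        t = torch.tensor([float(rank + 1)])
        work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
        # Independent computation could overlap with communication here.
        work.wait()  # for CPU collectives, blocks until the reduction finished
        result = (t.item(), work.is_completed())
        dist.barrier()  # keep every rank alive until all communication is done
        queue.put((rank, result))
    finally:
        dist.destroy_process_group()


def run(world_size: int = 2) -> dict:
    with socket.socket() as s:  # pick a free port for the rendezvous
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    ctx = mp.get_context("fork")  # assumption: POSIX host where fork is available
    queue = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(r, world_size, port, queue), daemon=True)
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=60) for _ in procs)  # {rank: result}
    for p in procs:
        p.join()
    return results
```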

Broadcasts the tensor to the whole group.

tensor must have the same number of elements in all processes participating in the collective.

tensor (Tensor) – Data to be sent if src is the rank of the current process, and tensor to be used to save received data otherwise.

src (int) – Source rank on global process group (regardless of group argument).

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

group_src (int) – Source rank on group. Must specify one of group_src and src but not both.

Returns: an async work handle if async_op is set to True; None if async_op is False or if the rank is not part of the group.
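
Here is a minimal broadcast sketch on CPU. It assumes PyTorch with Gloo and a POSIX host with "fork". Both ranks pass a tensor with the same number of elements. After the call, rank 1's zero-filled tensor should hold rank 0's data. `_worker` and `run` are illustrative helpers.

```python
import multiprocessing as mp
import socket

import torch
import torch.distributed as dist


def _worker(rank: int, world_size: int, port: int, queue) -> None:
    # Each process joins a CPU-only Gloo group over TCP on localhost.
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=rank, world_size=world_size)
    try:
        # Same number of elements on every rank; src's data overwrites the rest.
        t = torch.tensor([42.0, 7.0]) if rank == 0 else torch.zeros(2)
        dist.broadcast(t, src=0)
        result = t.tolist()
        dist.barrier()  # keep every rank alive until all communication is done
        queue.put((rank, result))
    finally:
        dist.destroy_process_group()


def run(world_size: int = 2) -> dict:
    with socket.socket() as s:  # pick a free port for the rendezvous
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    ctx = mp.get_context("fork")  # assumption: POSIX host where fork is available
    queue = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(r, world_size, port, queue), daemon=True)
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=60) for _ in procs)  # {rank: result}
    for p in procs:
        p.join()
    return results
```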

Broadcasts picklable objects in object_list to the whole group.

Similar to broadcast(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be broadcast.

object_list (List[Any]) – List of input objects to broadcast. Each object must be picklable. Only objects on the src rank will be broadcast, but each rank must provide lists of equal sizes.

src (int) – Source rank from which to broadcast object_list. Source rank is based on global process group (regardless of group argument).

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

device (torch.device, optional) – If not None, the objects are serialized and converted to tensors which are moved to the device before broadcasting. Default is None.

group_src (int) – Source rank on group. Must specify one of group_src and src but not both.

Returns: None. If the rank is part of the group, object_list will contain the broadcast objects from the src rank.

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Note that this API differs slightly from the broadcast() collective since it does not provide an async_op handle and thus will be a blocking call.

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

broadcast_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Calling broadcast_object_list() with GPU tensors is not well supported and is inefficient, as it incurs a GPU -> CPU transfer since tensors would be pickled. Please consider using broadcast() instead.
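
Here is a minimal sketch of sharing a small configuration dict from rank 0. It assumes PyTorch with Gloo and a POSIX host with "fork". The dict contents are hypothetical. Every rank passes a list of the same length, only rank 0's contents are sent, and the call blocks. `_worker` and `run` are illustrative helpers. As with any pickle-based API, use it only with trusted data.

```python
import multiprocessing as mp
import socket

import torch.distributed as dist


def _worker(rank: int, world_size: int, port: int, queue) -> None:
    # Each process joins a CPU-only Gloo group over TCP on localhost.
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=rank, world_size=world_size)
    try:
        # Equal-length lists on every rank; only src's contents are broadcast.
        objs = [{"lr": 0.001, "epochs": 3}] if rank == 0 else [None]
        dist.broadcast_object_list(objs, src=0)  # blocking, no async_op handle
        result = objs[0]
        dist.barrier()  # keep every rank alive until all communication is done
        queue.put((rank, result))
    finally:
        dist.destroy_process_group()


def run(world_size: int = 2) -> dict:
    with socket.socket() as s:  # pick a free port for the rendezvous
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    ctx = mp.get_context("fork")  # assumption: POSIX host where fork is available
    queue = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(r, world_size, port, queue), daemon=True)
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=60) for _ in procs)  # {rank: result}
    for p in procs:
        p.join()
    return results
```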

Reduces the tensor data across all machines in a way that all get the final result.

After the call, tensor is going to be bitwise identical in all processes.

Complex tensors are supported.

tensor (Tensor) – Input and output of the collective. The function operates in-place.

op (optional) – One of the values from the torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

Returns: an async work handle if async_op is set to True; None if async_op is False or if the rank is not part of the group.
|
|
1173
|
+
|
|
1174
|
+
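As a hedged illustration of the all_reduce semantics above, the sketch below is a pure-Python model in which each rank's tensor is a plain list. It does not call torch.distributed (the real call would be dist.all_reduce(tensor, op=dist.ReduceOp.SUM) on every rank); the helper simulate_all_reduce and the string op names are hypothetical.

```python
import operator
from functools import reduce as fold

# Element-wise reduction functions keyed by a ReduceOp-like name (illustrative only).
_OPS = {
    "SUM": operator.add,
    "PRODUCT": operator.mul,
    "MIN": min,
    "MAX": max,
}


def simulate_all_reduce(rank_tensors, op="SUM"):
    """Model of all_reduce: every rank's buffer is overwritten in place
    with the element-wise reduction over all ranks (not torch code)."""
    combine = _OPS[op]
    # zip(*...) groups the i-th element of every rank's tensor together.
    reduced = [fold(combine, column) for column in zip(*rank_tensors)]
    for buf in rank_tensors:
        buf[:] = reduced  # in place, like the real collective
    return rank_tensors


ranks = [[1, 2], [3, 4], [5, 6]]  # three ranks, two elements each
simulate_all_reduce(ranks, op="SUM")
print(ranks)  # [[9, 12], [9, 12], [9, 12]]
```

The model overwrites each list in place, mirroring the in-place behavior of the real collective.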
Reduces the tensor data across all machines.

Only the process with rank dst receives the final result.

tensor (Tensor) – Input and output of the collective. The function operates in-place.

dst (int) – Destination rank on the global process group (regardless of the group argument).

op (optional) – One of the values from the torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

group_dst (int) – Destination rank on group. Must specify one of group_dst and dst but not both.

Returns: An async work handle if async_op is set to True; None if async_op is False or if this rank is not part of the group.

Gathers tensors from the whole group in a list.

Complex and uneven-sized tensors are supported.

tensor_list (list[Tensor]) – Output list. It should contain correctly sized tensors to be used for the output of the collective. Uneven-sized tensors are supported.

tensor (Tensor) – Tensor to be broadcast from the current process.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

Returns: An async work handle if async_op is set to True; None if async_op is False or if this rank is not part of the group.

Gathers tensors from all ranks and puts them in a single output tensor.

This function requires all tensors to be the same size on each process.

output_tensor (Tensor) – Output tensor to accommodate tensor elements from all ranks. It must be correctly sized to have one of the following forms: (i) a concatenation of all the input tensors along the primary dimension; for the definition of “concatenation”, see torch.cat(); (ii) a stack of all the input tensors along the primary dimension; for the definition of “stack”, see torch.stack(). For example, if each of 2 ranks contributes an input tensor of shape [2], output_tensor can have shape [4] (concatenation) or shape [2, 2] (stack).

input_tensor (Tensor) – Tensor to be gathered from the current rank. Unlike the all_gather API, the input tensors in this API must have the same size across all ranks.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

Returns: An async work handle if async_op is set to True; None if async_op is False or if this rank is not part of the group.

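To make the two accepted output layouts concrete, here is a pure-Python model in which flat lists stand in for 1-D tensors and nested lists for a stacked 2-D tensor. It is a sketch, not the torch implementation; simulate_all_gather_into_tensor is a hypothetical helper.

```python
def simulate_all_gather_into_tensor(inputs, stacked=False):
    """Model of all_gather_into_tensor's two accepted output layouts.

    inputs[r] is rank r's input_tensor as a flat list; every rank must
    contribute the same number of elements. The return value is what every
    rank's output_tensor would hold afterwards (not torch code).
    """
    if len({len(t) for t in inputs}) != 1:
        raise ValueError("input tensors must have the same size on every rank")
    if stacked:
        # Layout (ii): like torch.stack, shape [world_size, n]
        return [list(t) for t in inputs]
    # Layout (i): like torch.cat along dim 0, shape [world_size * n]
    return [x for t in inputs for x in t]


inputs = [[1, 2], [3, 4]]  # world_size = 2, n = 2
print(simulate_all_gather_into_tensor(inputs))                # [1, 2, 3, 4]
print(simulate_all_gather_into_tensor(inputs, stacked=True))  # [[1, 2], [3, 4]]
```

Both layouts hold the same elements in rank order; they differ only in whether the rank dimension is flattened into the primary dimension.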
Gathers picklable objects from the whole group into a list.

Similar to all_gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered.

object_list (list[Any]) – Output list. Its length should equal the size of the group for this collective, and it will contain the output.

obj (Any) – Picklable Python object to be broadcast from the current process.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

Returns: None. If the calling rank is part of this group, the output of the collective will be populated into the input object_list. If the calling rank is not part of the group, the passed-in object_list will be unmodified.

Note that this API differs slightly from the all_gather() collective since it does not provide an async_op handle and is therefore a blocking call.

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

all_gather_object() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data that will execute arbitrary code during unpickling. Only call this function with data you trust.

Calling all_gather_object() with GPU tensors is not well supported and is inefficient, because the tensors are pickled, which incurs a GPU -> CPU transfer. Consider using all_gather() instead.

Gathers a list of tensors in a single process.

This function requires all tensors to be the same size on each process.

tensor (Tensor) – Input tensor.

gather_list (list[Tensor], optional) – List of appropriately sized tensors, all of the same size, to use for the gathered data (default is None; must be specified on the destination rank).

dst (int, optional) – Destination rank on the global process group (regardless of the group argument). (If both dst and group_dst are None, the default is global rank 0.)

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

group_dst (int, optional) – Destination rank on group. It is invalid to specify both dst and group_dst.

Returns: An async work handle if async_op is set to True; None if async_op is False or if this rank is not part of the group.

Note that all Tensors in gather_list must have the same size.

Gathers picklable objects from the whole group in a single process.

Similar to gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered.

obj (Any) – Input object. Must be picklable.

object_gather_list (list[Any]) – Output list. On the dst rank, its length should equal the size of the group for this collective, and it will contain the output. Must be None on non-dst ranks. (default is None)

dst (int, optional) – Destination rank on the global process group (regardless of the group argument). (If both dst and group_dst are None, the default is global rank 0.)

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

group_dst (int, optional) – Destination rank on group. It is invalid to specify both dst and group_dst.

Returns: None. On the dst rank, object_gather_list will contain the output of the collective.

Note that this API differs slightly from the gather() collective since it does not provide an async_op handle and is therefore a blocking call.

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

gather_object() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data that will execute arbitrary code during unpickling. Only call this function with data you trust.

Calling gather_object() with GPU tensors is not well supported and is inefficient, because the tensors are pickled, which incurs a GPU -> CPU transfer. Consider using gather() instead.

Scatters a list of tensors to all processes in a group.

Each process will receive exactly one tensor and store its data in the tensor argument.

Complex tensors are supported.

tensor (Tensor) – Output tensor.

scatter_list (list[Tensor]) – List of tensors to scatter (default is None; must be specified on the source rank).

src (int) – Source rank on the global process group (regardless of the group argument). (If both src and group_src are None, the default is global rank 0.)

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

group_src (int, optional) – Source rank on group. It is invalid to specify both src and group_src.

Returns: An async work handle if async_op is set to True; None if async_op is False or if this rank is not part of the group.

Note that all Tensors in scatter_list must have the same size.

Scatters picklable objects in scatter_object_input_list to the whole group.

Similar to scatter(), but Python objects can be passed in. On each rank, the scattered object will be stored as the first element of scatter_object_output_list. Note that all objects in scatter_object_input_list must be picklable in order to be scattered.

scatter_object_output_list (List[Any]) – Non-empty list whose first element will store the object scattered to this rank.

scatter_object_input_list (List[Any], optional) – List of input objects to scatter. Each object must be picklable. Only objects on the src rank will be scattered, and the argument can be None for non-src ranks.

src (int) – Source rank from which to scatter scatter_object_input_list. The source rank is based on the global process group (regardless of the group argument). (If both src and group_src are None, the default is global rank 0.)

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

group_src (int, optional) – Source rank on group. It is invalid to specify both src and group_src.

Returns: None. If rank is part of the group, scatter_object_output_list will have its first element set to the scattered object for this rank.

Note that this API differs slightly from the scatter() collective since it does not provide an async_op handle and is therefore a blocking call.

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

scatter_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data that will execute arbitrary code during unpickling. Only call this function with data you trust.

Calling scatter_object_list() with GPU tensors is not well supported and is inefficient, because the tensors are pickled, which incurs a GPU -> CPU transfer. Consider using scatter() instead.

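The following pure-Python model sketches which object ends up on which rank after scatter_object_list. A list of per-rank arguments stands in for the separate processes; scatter_object_list_model is a hypothetical helper, and its ValueError check is this model's own simplification rather than the library's error.

```python
def scatter_object_list_model(per_rank_inputs, src=0):
    """Model of scatter_object_list (not torch code).

    per_rank_inputs[r] is what rank r passes as scatter_object_input_list;
    only the src rank's list is used, so other ranks may pass None.
    Returns each rank's scatter_object_output_list after the call.
    """
    world_size = len(per_rank_inputs)
    objects = per_rank_inputs[src]
    if objects is None or len(objects) != world_size:
        raise ValueError("the src rank must supply one object per rank")
    outputs = []
    for rank in range(world_size):
        output_list = [None]            # caller-allocated, non-empty
        output_list[0] = objects[rank]  # only the first element is written
        outputs.append(output_list)
    return outputs


# Rank 1 is the source; ranks 0 and 2 pass None as their input list.
inputs = [None, ["a", {"lr": 0.1}, 3], None]
print(scatter_object_list_model(inputs, src=1))  # [['a'], [{'lr': 0.1}], [3]]
```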
Reduces, then scatters a list of tensors to all processes in a group.

output (Tensor) – Output tensor.

input_list (list[Tensor]) – List of tensors to reduce and scatter.

op (optional) – One of the values from the torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

Returns: An async work handle if async_op is set to True; None if async_op is False or if this rank is not part of the group.

Reduces, then scatters a tensor to all ranks in a group.

output (Tensor) – Output tensor. It should have the same size across all ranks.

input (Tensor) – Input tensor to be reduced and scattered. Its size should be the output tensor size times the world size. The input tensor can have one of the following shapes: (i) a concatenation of the output tensors along the primary dimension, or (ii) a stack of the output tensors along the primary dimension. For the definition of “concatenation”, see torch.cat(). For the definition of “stack”, see torch.stack().

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

Returns: An async work handle if async_op is set to True; None if async_op is False or if this rank is not part of the group.

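Here is a pure-Python model of reduce_scatter_tensor with a SUM reduction and the concatenation input layout, using flat lists in place of tensors. It only illustrates the data movement; simulate_reduce_scatter_tensor is a hypothetical name, and its size check mirrors the "output size times world size" rule stated above.

```python
def simulate_reduce_scatter_tensor(inputs):
    """Model of reduce_scatter_tensor with SUM (not torch code).

    inputs[r] is rank r's flat input, whose length must be
    world_size * chunk (the concatenation layout). Returns the output
    each rank receives: chunk r of the element-wise sum.
    """
    world_size = len(inputs)
    total = len(inputs[0])
    if any(len(t) != total for t in inputs) or total % world_size:
        raise ValueError("each input must hold output size * world_size elements")
    chunk = total // world_size
    summed = [sum(column) for column in zip(*inputs)]  # reduce (SUM) across ranks
    # Scatter: rank r keeps the r-th contiguous chunk of the reduced result.
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]


inputs = [[1, 2, 3, 4],      # rank 0
          [10, 20, 30, 40]]  # rank 1
print(simulate_reduce_scatter_tensor(inputs))  # [[11, 22], [33, 44]]
```

Each rank ends up with only its own slice of the reduced result, which is why the output is world_size times smaller than the input.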
Splits the input tensor and then scatters the split list to all processes in a group.

The tensors received from all processes in the group are then concatenated and returned as a single output tensor.

Complex tensors are supported.

output (Tensor) – Gathered concatenated output tensor.

input (Tensor) – Input tensor to scatter.

output_split_sizes (list[int], optional) – Output split sizes for dim 0. If None or empty, dim 0 of the output tensor must divide equally by world_size.

input_split_sizes (list[int], optional) – Input split sizes for dim 0. If None or empty, dim 0 of the input tensor must divide equally by world_size.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

Returns: An async work handle if async_op is set to True; None if async_op is False or if this rank is not part of the group.

all_to_all_single is experimental and subject to change.

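The split-size rules are easiest to see in a pure-Python model of all_to_all_single, with flat lists standing in for each rank's input. This sketch is not the torch implementation: it derives each rank's output from the senders' input_split_sizes, whereas with the real API every rank must also allocate an output of the right size and pass matching output_split_sizes. simulate_all_to_all_single is a hypothetical helper.

```python
from itertools import accumulate


def simulate_all_to_all_single(inputs, input_split_sizes=None):
    """Model of all_to_all_single over flat lists (not torch code).

    inputs[r] is rank r's input. input_split_sizes[r], if given, lists how many
    consecutive elements rank r sends to ranks 0, 1, ... in order; if omitted,
    each input must divide equally by world_size. Returns every rank's output.
    """
    world_size = len(inputs)
    chunks = []  # chunks[src][dst]: elements that rank src sends to rank dst
    for src, data in enumerate(inputs):
        if input_split_sizes is None:
            if len(data) % world_size:
                raise ValueError("dim 0 must divide equally by world_size")
            sizes = [len(data) // world_size] * world_size
        else:
            sizes = input_split_sizes[src]
        if len(sizes) != world_size or sum(sizes) != len(data):
            raise ValueError("need one split size per rank, summing to dim 0")
        bounds = [0, *accumulate(sizes)]
        chunks.append([data[bounds[i]:bounds[i + 1]] for i in range(world_size)])
    # Rank dst concatenates what it received, ordered by source rank.
    return [[x for src in range(world_size) for x in chunks[src][dst]]
            for dst in range(world_size)]


print(simulate_all_to_all_single([[0, 1, 2, 3], [10, 11, 12, 13]]))
# [[0, 1, 10, 11], [2, 3, 12, 13]]
print(simulate_all_to_all_single([[0, 1, 2, 3], [10, 11, 12]],
                                 input_split_sizes=[[1, 3], [2, 1]]))
# [[0, 10, 11], [1, 2, 3, 12]]
```

In the uneven case, rank 0 would need output_split_sizes [1, 2] and rank 1 would need [3, 1] with the real API, because each entry must equal what the corresponding source rank sends.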
Scatters a list of input tensors to all processes in a group and returns the gathered list of tensors in the output list.

Complex tensors are supported.

output_tensor_list (list[Tensor]) – List of tensors to be gathered, one per rank.

input_tensor_list (list[Tensor]) – List of tensors to scatter, one per rank.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

Returns: An async work handle if async_op is set to True; None if async_op is False or if this rank is not part of the group.

all_to_all is experimental and subject to change.

Synchronizes all processes.

If async_op is False, this collective blocks processes until the whole group enters this function; if async_op is True, the blocking happens when wait() is called on the returned async work handle.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

async_op (bool, optional) – Whether this op should be an async op.

device_ids ([int], optional) – List of device/GPU ids. Only one id is expected.

Returns: An async work handle if async_op is set to True; None if async_op is False or if this rank is not part of the group.

ProcessGroupNCCL now blocks the CPU thread until the barrier collective completes.

ProcessGroupNCCL implements barrier as an all_reduce of a 1-element tensor. A device must be chosen for allocating this tensor. The device is chosen by checking, in this order: (1) the first device passed to the device_ids argument of barrier, if not None; (2) the device passed to init_process_group, if not None; (3) the device that was first used with this process group, if another collective with tensor inputs has been performed; (4) the device index given by the global rank modulo the local device count.

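The device-selection order above can be written as a small decision function. This is a hypothetical model of the documented rules only; the function name and its keyword arguments are illustrative and are not part of torch.distributed.

```python
def pick_barrier_device(global_rank, local_device_count, device_ids=None,
                        init_device=None, first_used_device=None):
    """Model of the documented device-selection order for the NCCL barrier.

    Returns a CUDA device index. Argument names are illustrative, not the
    torch.distributed API.
    """
    if device_ids is not None:         # (1) first id passed via barrier(device_ids=...)
        return device_ids[0]
    if init_device is not None:        # (2) device passed to init_process_group
        return init_device
    if first_used_device is not None:  # (3) device of an earlier tensor collective
        return first_used_device
    return global_rank % local_device_count  # (4) global rank mod local device count


# Rank 11 on nodes with 8 GPUs falls back to device 11 % 8 == 3.
print(pick_barrier_device(global_rank=11, local_device_count=8))  # 3
```

The fallback in step (4) is the usual reason to pass device_ids explicitly: if a rank's actual device differs from global_rank % local_device_count, the barrier tensor would otherwise be allocated on a different GPU.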
Synchronizes processes similarly to torch.distributed.barrier, but with a configurable timeout.

It is able to report ranks that did not pass this barrier within the provided timeout. Specifically, non-zero ranks block until a send/recv is processed from rank 0. Rank 0 blocks until all sends/recvs from other ranks are processed, and reports failures for ranks that failed to respond in time. Note that if one rank does not reach the monitored_barrier (for example, due to a hang), all other ranks will fail in monitored_barrier.

This collective blocks all processes/ranks in the group until the whole group exits the function successfully, making it useful for debugging and synchronizing. However, it can have a performance impact and should only be used for debugging or in scenarios that require full synchronization points on the host side. For debugging purposes, this barrier can be inserted before the application’s collective calls to check whether any ranks are desynchronized.

Note that this collective is only supported with the Gloo backend.

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

timeout (datetime.timedelta, optional) – Timeout for monitored_barrier. If None, the default process group timeout will be used.

wait_all_ranks (bool, optional) – Whether to collect all failed ranks. By default, this is False, and monitored_barrier on rank 0 will throw on the first failed rank it encounters in order to fail fast. By setting wait_all_ranks=True, monitored_barrier will collect all failed ranks and throw an error containing information about all of them.

A Work object represents the handle to a pending asynchronous operation in PyTorch’s distributed package. It is returned by non-blocking collective operations, such as dist.all_reduce(tensor, async_op=True).

Blocks the currently active GPU stream on the operation to complete. For GPU-based collectives, this is equivalent to synchronize. For CPU-initiated collectives, such as with Gloo, this will block the CUDA stream until the operation is complete.

This returns immediately in all cases.

To check whether an operation was successful, you should check the Work object result asynchronously.

Returns: A torch.futures.Future object which is associated with the completion of the Work. As an example, a future object can be retrieved by fut = process_group.allreduce(tensors).get_future().

A typical use is a simple allreduce DDP communication hook that uses the get_future API to retrieve a Future associated with the completion of allreduce.

The get_future API supports NCCL, and partially supports the Gloo and MPI backends (no support for peer-to-peer operations like send/recv), and will return a torch.futures.Future.

In such a hook, where the allreduce work runs on GPU with the NCCL backend, fut.wait() returns after synchronizing the appropriate NCCL streams with PyTorch’s current device streams. This allows asynchronous CUDA execution: fut.wait() does not wait for the entire operation to complete on GPU. Note that CUDAFuture does not support the TORCH_NCCL_BLOCKING_WAIT flag or NCCL’s barrier(). In addition, if a callback function was added by fut.then(), it will wait until WorkNCCL’s NCCL streams synchronize with ProcessGroupNCCL’s dedicated callback stream and invoke the callback inline after running the callback on the callback stream. fut.then() will return another CUDAFuture that holds the return value of the callback and a CUDAEvent that recorded the callback stream.

For CPU work, fut.done() returns true when the work has been completed and the value() tensors are ready.

For GPU work, fut.done() returns true once the operation has been enqueued; it does not indicate that the operation has finished on the GPU.

For mixed CPU-GPU work (e.g., sending GPU tensors with Gloo), fut.done() returns true when the tensors have arrived on the respective nodes, even though they are not yet necessarily synchronized on the respective GPUs (similar to GPU work).

Returns: A torch.futures.Future object of int type that maps to the enum type of WorkResult. As an example, a future object can be retrieved by fut = process_group.allreduce(tensor).get_future_result().

Users can call fut.wait() to block until the work completes and then get the WorkResult via fut.value(). Users can also call fut.then(call_back_func) to register a callback function to be called when the work is completed, without blocking the current thread.

The get_future_result API supports the NCCL backend.

In normal cases, users do not need to set the timeout. Calling wait() is the same as calling synchronize(): it lets the current stream block on the completion of the NCCL work. However, if timeout is set, it will block the CPU thread until the NCCL work is completed or times out. If it times out, an exception will be thrown.

An enum-like class for available reduction operations: SUM, PRODUCT, MIN, MAX, BAND, BOR, BXOR, and PREMUL_SUM.

BAND, BOR, and BXOR reductions are not available when using the NCCL backend.

AVG divides values by the world size before summing across ranks. AVG is only available with the NCCL backend, and only for NCCL versions 2.10 or later.

PREMUL_SUM multiplies inputs by a given scalar locally before reduction. PREMUL_SUM is only available with the NCCL backend, and only for NCCL versions 2.11 or later. Users are expected to construct it with torch.distributed._make_nccl_premul_sum.

Additionally, MAX, MIN and PRODUCT are not supported for complex tensors.

The values of this class can be accessed as attributes, e.g., ReduceOp.SUM. They are used in specifying strategies for reduction collectives, e.g., reduce().

This class does not support the `__members__` property.

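A worked numerical example of the AVG and PREMUL_SUM descriptions above, written as plain Python over one scalar per rank. These helpers are hypothetical models of the arithmetic, not the NCCL implementation.

```python
def model_avg(values):
    """AVG as described: each rank's value is divided by world_size, then summed."""
    world_size = len(values)
    return sum(v / world_size for v in values)


def model_premul_sum(values, factor):
    """PREMUL_SUM as described: each rank scales its input locally, then SUM."""
    return sum(factor * v for v in values)


values = [2.0, 4.0, 6.0, 8.0]  # one scalar per rank, world_size = 4
print(model_avg(values))              # 5.0
print(model_premul_sum(values, 0.5))  # 10.0
```

With factor = 1 / world_size (0.25 here), the PREMUL_SUM model returns the same 5.0 as the AVG model.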
Deprecated enum-like class for reduction operations: SUM, PRODUCT, MIN, and MAX.

Use ReduceOp instead.

The distributed package comes with a distributed key-value store, which can be used to share information between processes in the group as well as to initialize the distributed package in torch.distributed.init_process_group() (by explicitly creating the store as an alternative to specifying init_method). There are 3 choices for key-value stores: TCPStore, FileStore, and HashStore.

Base class for all store implementations, such as the 3 provided by PyTorch distributed (TCPStore, FileStore, and HashStore).

The first call to add for a given key creates a counter associated with key in the store, initialized to amount. Subsequent calls to add with the same key increment the counter by the specified amount. Calling add() with a key that has already been set in the store by set() will result in an exception.

key (str) – The key in the store whose counter will be incremented.

amount (int) – The quantity by which the counter will be incremented.

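A minimal dictionary-backed model of the add() rules described above. CounterStoreModel is hypothetical and not a torch class; it returns the updated count for convenience, and the RuntimeError it raises stands in for whatever exception the real store reports.

```python
class CounterStoreModel:
    """Dict-backed model of the documented add()/set() rules; not torch code."""

    def __init__(self):
        self._counters = {}  # keys created by add()
        self._values = {}    # keys written by set()

    def set(self, key, value):
        self._values[key] = value

    def add(self, key, amount):
        if key in self._values:
            # Documented rule: add() on a key that set() wrote raises.
            raise RuntimeError(f"{key!r} was written by set(); add() is not allowed")
        # The first add() creates the counter at `amount`; later calls increment it.
        self._counters[key] = self._counters.get(key, 0) + amount
        return self._counters[key]


store = CounterStoreModel()
store.add("first_key", 1)
print(store.add("first_key", 6))  # 7
```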
Appends the key-value pair into the store based on the supplied key and value. If key does not exist in the store, it will be created.

key (str) – The key to be appended to the store.

value (str) – The value associated with key to be added to the store.

Checks whether each key in a given list has a value stored in the store. This call normally returns immediately, but it can still deadlock in some edge cases, e.g., calling check() after the TCPStore has been destroyed. Call check() with the list of keys whose presence in the store you want to test.

keys (list[str]) – The keys to check for in the store.

Clones the store and returns a new object that points to the same underlying store. The returned store can be used concurrently with the original object. This is intended to provide a safe way to use a store from multiple threads by cloning one store per thread.

Inserts the key-value pair into the store based on the supplied key, after comparing the key's current value with expected_value. desired_value is set only if the key already exists in the store and its current value equals expected_value, or if the key does not exist and expected_value is an empty string.

key (str) – The key to be checked in the store.

expected_value (str) – The value associated with key to be checked before insertion.

desired_value (str) – The value associated with key to be added to the store.

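The compare_set condition can be sketched as a function over a plain dict of string keys and values. compare_set_model is a hypothetical helper that models only which writes happen; it does not model the real method's return value.

```python
def compare_set_model(store, key, expected_value, desired_value):
    """Model of the compare_set rule on a plain dict of str -> str (not torch code)."""
    current = store.get(key)  # None means the key is absent
    key_missing_and_expected_empty = current is None and expected_value == ""
    key_present_and_matches = current is not None and current == expected_value
    if key_missing_and_expected_empty or key_present_and_matches:
        store[key] = desired_value
    # Otherwise the store is left unchanged.


kv = {}
compare_set_model(kv, "leader", "", "rank0")       # absent + empty expected -> set
compare_set_model(kv, "leader", "rank7", "rank1")  # mismatch -> unchanged
print(kv)  # {'leader': 'rank0'}
```

The empty-string case is what makes compare_set usable as a "create only if absent" primitive, for example when several ranks race to claim the same key.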
Deletes the key-value pair associated with key from the store. Returns true if the key was successfully deleted, and false if it was not.

The delete_key API is only supported by the TCPStore and HashStore. Using this API with the FileStore will result in an exception.

key (str) – The key to be deleted from the store.

Returns: True if key was deleted, otherwise False.

Retrieves the value associated with the given key in the store. If key is not present in the store, the function will wait for timeout, which is defined when initializing the store, before throwing an exception.

key (str) – The function will return the value associated with this key.

Returns: Value associated with key if key is in the store.

Returns true if the store supports extended operations.

Retrieves all values in keys. If any key in keys is not present in the store, the function will wait for timeout.

keys (List[str]) – The keys to be retrieved from the store.

Inserts a list of key-value pairs into the store based on the supplied keys and values.

keys (List[str]) – The keys to insert.

values (List[str]) – The values to insert.

Returns the number of keys set in the store. Note that this number will typically be one greater than the number of keys added by set() and add() since one key is used to coordinate all the workers using the store.

When used with the FileStore, num_keys returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained.

Returns: The number of keys present in the store.

Returns the length of the specified queue.

If the queue doesn’t exist, it returns 0.

See queue_push for more details.

key (str) – The key of the queue whose length to get.

Pops a value from the specified queue, or waits until timeout if the queue is empty.

See queue_push for more details.

If block is False, a dist.QueueEmptyError will be raised if the queue is empty.

key (str) – The key of the queue to pop from.

block (bool) – Whether to block waiting for the key or return immediately.

Pushes a value into the specified queue.

Using the same key for queues and set/get operations may result in unexpected behavior.

wait/check operations are supported for queues.

wait with queues will only wake one waiting worker rather than all.

key (str) – The key of the queue to push to.

value (str) – The value to push into the queue.

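A minimal in-process model of the queue operations above. It assumes first-in, first-out ordering, which this page does not state. QueueStoreModel and its local QueueEmptyError are hypothetical stand-ins; the real store's timeout and dist.QueueEmptyError behave as documented, not necessarily as modeled here.

```python
import threading
from collections import defaultdict, deque


class QueueEmptyError(Exception):
    """Stand-in for dist.QueueEmptyError inside this model."""


class QueueStoreModel:
    """In-process model of queue_push/queue_pop/queue_len (assumes FIFO order)."""

    def __init__(self, timeout=1.0):
        self._timeout = timeout  # plays the role of the store timeout (seconds)
        self._queues = defaultdict(deque)
        self._cond = threading.Condition()

    def queue_push(self, key, value):
        with self._cond:
            self._queues[key].append(value)
            # wait_for() in queue_pop re-checks the queue, so only one waiter gets each value.
            self._cond.notify_all()

    def queue_len(self, key):
        with self._cond:
            return len(self._queues.get(key, ()))  # a missing queue has length 0

    def queue_pop(self, key, block=True):
        with self._cond:
            if not block and not self._queues[key]:
                raise QueueEmptyError(key)
            if not self._cond.wait_for(lambda: self._queues[key], timeout=self._timeout):
                raise TimeoutError(f"queue {key!r} stayed empty past the timeout")
            return self._queues[key].popleft()


store = QueueStoreModel(timeout=0.1)
store.queue_push("jobs", "a")
store.queue_push("jobs", "b")
print(store.queue_len("jobs"))         # 2
print(store.queue_pop("jobs"))         # a
print(store.queue_pop("jobs", False))  # b
```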
Inserts the key-value pair into the store based on the supplied key and value. If key already exists in the store, it will overwrite the old value with the new supplied value.
|
|
1569
|
+
|
|
1570
|
+
key (str) – The key to be added to the store.
|
|
1571
|
+
|
|
1572
|
+
value (str) – The value associated with key to be added to the store.
|
|
1573
|
+
|
|
1574
|
+
Sets the store’s default timeout. This timeout is used during initialization and in wait() and get().
|
|
1575
|
+
|
|
1576
|
+
timeout (timedelta) – timeout to be set in the store.
|
|
1577
|
+
|
|
1578
|
+
Gets the timeout of the store.
|
|
1579
|
+
|
|
1580
|
+
wait(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str]) -> None
|
|
1581
|
+
|
|
1582
|
+
Waits for each key in keys to be added to the store. If not all keys are set before the timeout (set during store initialization), then wait will throw an exception.
|
|
1583
|
+
|
|
1584
|
+
keys (list) – List of keys on which to wait until they are set in the store.
|
|
1585
|
+
|
|
1586
|
+
wait(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str], arg1: datetime.timedelta) -> None
|
|
1587
|
+
|
|
1588
|
+
Waits for each key in keys to be added to the store, and throws an exception if the keys have not been set by the supplied timeout.
|
|
1589
|
+
|
|
1590
|
+
keys (list) – List of keys on which to wait until they are set in the store.
|
|
1591
|
+
|
|
1592
|
+
timeout (timedelta) – Time to wait for the keys to be added before throwing an exception.
|
|
1593
|
+
|
|
1594
|
+
A TCP-based distributed key-value store implementation. The server store holds the data, while the client stores can connect to the server store over TCP and perform actions such as set() to insert a key-value pair, get() to retrieve a key-value pair, etc. There should always be one server store initialized because the client store(s) will wait for the server to establish a connection.
host_name (str) – The hostname or IP Address the server store should run on.
port (int) – The port on which the server store should listen for incoming requests.
world_size (int, optional) – The total number of store users (number of clients + 1 for the server). Default is None (None indicates a non-fixed number of store users).
is_master (bool, optional) – True when initializing the server store and False for client stores. Default is False.
timeout (timedelta, optional) – Timeout used by the store during initialization and for methods such as get() and wait(). Default is timedelta(seconds=300)
wait_for_workers (bool, optional) – Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True.
multi_tenant (bool, optional) – If True, all TCPStore instances in the current process with the same host/port will use the same underlying TCPServer. Default is False.
master_listen_fd (int, optional) – If specified, the underlying TCPServer will listen on this file descriptor, which must be a socket already bound to port. To bind an ephemeral port we recommend setting the port to 0 and reading .port. Default is None (meaning the server creates a new socket and attempts to bind it to port).
use_libuv (bool, optional) – If True, use libuv for TCPServer backend. Default is True.
Creates a new TCPStore.
Gets the hostname on which the store listens for requests.
Returns True if it’s using the libuv backend.
Gets the port number on which the store listens for requests.
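A minimal sketch of one TCPStore server and one client. Both run in the same process purely for demonstration; in a real job the client lives in another process or on another host. The find_free_port helper and the key names are hypothetical, and a free port is picked up front so the sketch does not depend on ephemeral-port handling:

```python
import socket
from datetime import timedelta

import torch.distributed as dist


def find_free_port() -> int:
    # Hypothetical helper: let the OS pick an unused TCP port on localhost.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


port = find_free_port()
timeout = timedelta(seconds=30)

# Exactly one server store (is_master=True) holds the data. world_size=None
# means a non-fixed number of users, so nobody blocks waiting for workers.
server = dist.TCPStore("127.0.0.1", port, world_size=None, is_master=True, timeout=timeout)

# Client stores connect to the server over TCP.
client = dist.TCPStore("127.0.0.1", port, world_size=None, is_master=False, timeout=timeout)

client.set("status", "ready")
status = server.get("status")                    # b'ready'
counter_after_client = client.add("counter", 5)  # 5
counter_after_server = server.add("counter", 1)  # 6
print(status, counter_after_server, server.host, server.port)
```

Because every client talks to the same server, a value written through any handle is visible through all of them.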
A thread-safe store implementation based on an underlying hashmap. This store can be used within the same process (for example, by other threads), but cannot be used across processes.
Creates a new HashStore.
A store implementation that uses a file to store the underlying key-value pairs.
file_name (str) – path of the file in which to store the key-value pairs
world_size (int, optional) – The total number of processes using the store. Default is -1 (a negative value indicates a non-fixed number of store users).
Creates a new FileStore.
Gets the path of the file used by FileStore to store key-value pairs.
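FileStore can be exercised the same way. This sketch assumes a writable local temporary directory and opens two handles on one file to stand in for two processes sharing the store (world_size=2):

```python
import os
import tempfile

import torch.distributed as dist

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "filestore")
    # Two handles on the same file stand in for two processes.
    writer = dist.FileStore(path, 2)
    reader = dist.FileStore(path, 2)

    writer.set("epoch", "3")
    epoch = reader.get("epoch")       # b'3'
    same_path = reader.path == path   # path reports the file backing the store

    # Release both handles before the directory is removed.
    del writer, reader
```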
A wrapper around any of the 3 key-value stores (TCPStore, FileStore, and HashStore) that adds a prefix to each key inserted to the store.
prefix (str) – The prefix string that is prepended to each key before being inserted into the store.
store (torch.distributed.Store) – A store object that forms the underlying key-value store.
Creates a new PrefixStore.
Gets the underlying store object that PrefixStore wraps around.
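A sketch of using PrefixStore to give two components separate namespaces on one underlying store. It assumes the prefix and key are joined with "/" in the underlying store, which matches current c10d behavior but should be treated as an implementation detail; the prefixes are hypothetical:

```python
import torch.distributed as dist

base = dist.HashStore()

# Two logical namespaces sharing one underlying store.
trainer = dist.PrefixStore("trainer", base)
evaluator = dist.PrefixStore("evaluator", base)

trainer.set("status", "running")
evaluator.set("status", "idle")  # same key, different namespace: no clash

print(trainer.get("status"))       # b'running'
print(evaluator.get("status"))     # b'idle'
print(base.get("trainer/status"))  # the prefixed key in the underlying store

# underlying_store returns the wrapped store.
shared = trainer.underlying_store
```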
Note that you can use torch.profiler (recommended, only available after 1.8.1) or torch.autograd.profiler to profile the collective communication and point-to-point communication APIs mentioned here. All out-of-the-box backends (gloo, nccl, mpi) are supported, and collective communication usage will be rendered as expected in profiling output/traces. Profiling your code works the same way as profiling any regular torch operator.
Please refer to the profiler documentation for a full overview of profiler features.
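A minimal profiling sketch, assuming a single-process gloo group (world_size=1) so it runs on a CPU-only machine; HashStore is used for rendezvous only because everything lives in one process. Exact event names in the trace vary between releases, so the check only looks for an all-reduce entry:

```python
import torch
import torch.distributed as dist
from torch.profiler import ProfilerActivity, profile

# Single-process gloo group: enough to exercise the collectives locally.
dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)

tensor = torch.ones(1024)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    dist.all_reduce(tensor)       # collectives are profiled like any other op
    dist.broadcast(tensor, src=0)

event_names = [evt.key for evt in prof.key_averages()]
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

dist.destroy_process_group()
```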
The multi-GPU functions (which stand for multiple GPUs per CPU thread) are deprecated. As of today, PyTorch Distributed’s preferred programming model is one device per thread, as exemplified by the APIs in this document. If you are a backend developer and want to support multiple devices per thread, please contact PyTorch Distributed’s maintainers.
Object collectives have a number of serious limitations. Read further to determine if they are safe to use for your use case.
Object collectives are a set of collective-like operations that work on arbitrary Python objects, as long as they can be pickled. There are various collective patterns implemented (e.g. broadcast, all_gather, …) but they each roughly follow this pattern:
convert the input object into a pickle (raw bytes), then shove it into a byte tensor
communicate the size of this byte tensor to peers (first collective operation)
allocate appropriately sized tensor to perform the real collective
communicate the object data (second collective operation)
convert raw data back into Python (unpickle)
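The five steps above can be simulated in plain Python to see where the two communication rounds and the padding come from. This is a hypothetical single-process model of all_gather_object-style behavior, not PyTorch's implementation; the function name is illustrative:

```python
import pickle


def simulate_all_gather_object(objects_per_rank):
    """Return what every rank would receive; objects_per_rank[r] is rank r's object."""
    # 1. Pickle each rank's object into raw bytes (the "byte tensor").
    payloads = [pickle.dumps(obj) for obj in objects_per_rank]

    # 2. First collective: every rank learns every payload size.
    sizes = [len(p) for p in payloads]

    # 3. Allocate equally sized buffers, padded to the largest payload.
    max_size = max(sizes)
    buffers = [p.ljust(max_size, b"\0") for p in payloads]

    # 4. Second collective: exchange the padded buffers (a plain copy here,
    #    standing in for the real all_gather of byte tensors).
    gathered = list(buffers)

    # 5. Trim the padding using the sizes, then unpickle back into Python.
    return [pickle.loads(buf[:size]) for buf, size in zip(gathered, sizes)]
```

The padding in step 3 is one reason a single large object on one rank inflates memory use on every rank.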
Object collectives sometimes have surprising performance or memory characteristics that lead to long runtimes or OOMs, and thus they should be used with caution. Here are some common issues.
Asymmetric pickle/unpickle time - Pickling objects can be slow, depending on the number, type and size of the objects. When the collective has a fan-in (e.g. gather_object), the receiving rank(s) must unpickle N times more objects than the sending rank(s) had to pickle, which can cause other ranks to time out on their next collective.
Inefficient tensor communication - Tensors should be sent via regular collective APIs, not object collective APIs. It is possible to send Tensors via object collective APIs, but they will be serialized and deserialized (including a CPU-sync and device-to-host copy in the case of non-CPU tensors), and in almost every case other than debugging or troubleshooting code, it would be worth the trouble to refactor the code to use non-object collectives instead.
Unexpected tensor devices - If you still want to send tensors via object collectives, there is another aspect specific to cuda (and possibly other accelerators) tensors. If you pickle a tensor that is currently on cuda:3, and then unpickle it, you will get another tensor on cuda:3 regardless of which process you are on, or which CUDA device is the ‘default’ device for that process. With regular tensor collective APIs, ‘output tensors’ will always be on the same, local device, which is generally what you’d expect.
Unpickling a tensor will implicitly activate a CUDA context if it is the first time a GPU is used by the process, which can waste significant amounts of GPU memory. This issue can be avoided by moving tensors to CPU before passing them as inputs to an object collective.
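Putting the advice above together: send tensors through regular collectives, and if an object collective is still needed for Python metadata, move any tensors in it to CPU first. A sketch under the same single-process gloo assumption (world_size=1); the dictionary keys are hypothetical:

```python
import torch
import torch.distributed as dist

dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
world_size = dist.get_world_size()

# Tensors: use the regular tensor collective.
local_loss = torch.tensor([0.25])
gathered_losses = [torch.empty_like(local_loss) for _ in range(world_size)]
dist.all_gather(gathered_losses, local_loss)

# Python metadata: an object collective is acceptable here. The tensor inside
# is moved to CPU first, so unpickling never recreates it on a CUDA device.
metadata = {"step": 10, "best_loss": local_loss.detach().cpu()}
gathered_metadata = [None] * world_size
dist.all_gather_object(gathered_metadata, metadata)

dist.destroy_process_group()
```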
Besides the builtin GLOO/MPI/NCCL backends, PyTorch distributed supports third-party backends through a run-time register mechanism. For references on how to develop a third-party backend through C++ Extension, please refer to Tutorials - Custom C++ and CUDA Extensions and test/cpp_extensions/cpp_c10d_extension.cpp. The capabilities of third-party backends are decided by their own implementations.
The new backend derives from c10d::ProcessGroup and registers the backend name and the instantiating interface through torch.distributed.Backend.register_backend() when imported.
When manually importing this backend and invoking torch.distributed.init_process_group() with the corresponding backend name, the torch.distributed package runs on the new backend.
The support of third-party backend is experimental and subject to change.
The torch.distributed package also provides a launch utility in torch.distributed.launch. This helper utility can be used to launch multiple processes per node for distributed training.
Module torch.distributed.launch.
torch.distributed.launch is a module that spawns up multiple distributed training processes on each of the training nodes.
This module is going to be deprecated in favor of torchrun.
The utility can be used for single-node distributed training, in which one or more processes per node will be spawned. The utility can be used for either CPU training or GPU training. If the utility is used for GPU training, each distributed process will operate on a single GPU, which can improve single-node training performance. It can also be used for multi-node distributed training, by spawning multiple processes on each node, to improve multi-node distributed training performance as well. This is especially beneficial for systems with multiple Infiniband interfaces that have direct-GPU support, since all of them can be utilized for aggregated communication bandwidth.
In both cases of single-node distributed training or multi-node distributed training, this utility will launch the given number of processes per node (--nproc-per-node). If used for GPU training, this number needs to be less than or equal to the number of GPUs on the current system (nproc_per_node), and each process will operate on a single GPU from GPU 0 to GPU (nproc_per_node - 1).
How to use this module:
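Single-node multi-process distributed training: python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE YOUR_TRAINING_SCRIPT.py, followed by your training script's own arguments (NUM_GPUS_YOU_HAVE and YOUR_TRAINING_SCRIPT.py are placeholders).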
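Multi-node multi-process distributed training (e.g. two nodes), where Node 1 has IP 192.168.1.1 and a free port 1234: on Node 1 run python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE --nnodes=2 --node-rank=0 --master-addr="192.168.1.1" --master-port=1234 YOUR_TRAINING_SCRIPT.py, and on Node 2 run the same command with --node-rank=1 (NUM_GPUS_YOU_HAVE and YOUR_TRAINING_SCRIPT.py are placeholders).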
To look up what optional arguments this module offers, run python -m torch.distributed.launch --help.
1. This utility and multi-process distributed (single-node or multi-node) GPU training currently only achieves the best performance using the NCCL distributed backend. Thus NCCL backend is the recommended backend to use for GPU training.
2. In your training program, you must parse the command-line argument: --local-rank=LOCAL_PROCESS_RANK, which will be provided by this module. If your training program uses GPUs, you should ensure that your code only runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:
Parsing the local_rank argument, e.g. with argparse: parser.add_argument("--local-rank", "--local_rank", type=int)
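Setting your device to the local rank, using either torch.cuda.set_device(args.local_rank) before your code runs, or by running your code inside with torch.cuda.device(args.local_rank):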
Changed in version 2.0.0: The launcher passes the --local-rank=<rank> argument to your script. From PyTorch 2.0.0 onwards, the dashed --local-rank is preferred over the previously used underscored --local_rank.
For backward compatibility, it may be necessary for users to handle both cases in their argument parsing code. This means including both "--local-rank" and "--local_rank" in the argument parser. If only "--local_rank" is provided, the script's argument parser will fail with "error: unrecognized arguments: --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+, including "--local-rank" should be sufficient.
3. In your training program, you are supposed to call torch.distributed.init_process_group() at the beginning to start the distributed backend. It is strongly recommended that init_method=env://. Other init methods (e.g. tcp://) may work, but env:// is the one that is officially supported by this module.
4. In your training program, you can either use regular distributed functions or use torch.nn.parallel.DistributedDataParallel() module. If your training program uses GPUs for training and you would like to use torch.nn.parallel.DistributedDataParallel() module, here is how to configure it.
Please ensure that the device_ids argument is set to the only GPU device id that your code will be operating on. This is generally the local rank of the process. In other words, to use this utility, device_ids needs to be [args.local_rank] and output_device needs to be args.local_rank, e.g. model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank).
5. Another way to pass local_rank to the subprocesses is via the environment variable LOCAL_RANK. This behavior is enabled when you launch the script with --use-env=True. In that case, replace args.local_rank with int(os.environ['LOCAL_RANK']) in your training code; the launcher will not pass --local-rank when you specify this flag.
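Items 2 and 5 can be handled by one small helper that accepts both --local-rank and --local_rank and falls back to the LOCAL_RANK environment variable. A sketch using only argparse; parse_local_rank is a hypothetical name, and its argv/env parameters exist only to make it easy to test:

```python
import argparse
import os


def parse_local_rank(argv=None, env=None):
    """Return the local rank from --local-rank/--local_rank, else LOCAL_RANK, else 0."""
    env = os.environ if env is None else env
    parser = argparse.ArgumentParser()
    # Accept both spellings: the dashed form is what the launcher passes from
    # PyTorch 2.0.0 onwards; the underscored form keeps older launchers working.
    parser.add_argument("--local-rank", "--local_rank", dest="local_rank",
                        type=int, default=None)
    args, _unknown = parser.parse_known_args(argv)
    if args.local_rank is not None:
        return args.local_rank
    # With --use-env=True the launcher sets LOCAL_RANK instead of passing a flag.
    return int(env.get("LOCAL_RANK", 0))
```

In a training script you would call local_rank = parse_local_rank() and then torch.cuda.set_device(local_rank) before building the model.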
local_rank is NOT globally unique: it is only unique per process on a machine. Thus, don’t use it to decide if you should, e.g., write to a networked filesystem. See pytorch/pytorch#12042 for an example of how things can go wrong if you don’t do this correctly.
The torch.multiprocessing package also provides a spawn function, torch.multiprocessing.spawn(). This helper function can be used to spawn multiple processes: you pass in the function that you want to run, and it spawns N processes to run it. This can be used for multiprocess distributed training as well.
For an example of how to use it, please refer to the PyTorch example ImageNet implementation.
Note that this function requires Python 3.4 or higher.
Debugging distributed applications can be challenging due to hard to understand hangs, crashes, or inconsistent behavior across ranks. torch.distributed provides a suite of tools to help debug training applications in a self-serve fashion:
It is extremely convenient to use python’s debugger in a distributed environment, but because it does not work out of the box many people do not use it at all. PyTorch offers a customized wrapper around pdb that streamlines the process.
torch.distributed.breakpoint makes this process easy. Internally, it customizes pdb’s breakpoint behavior in three ways but otherwise behaves as normal pdb.
Attaches the debugger only on one rank (specified by the user).
Ensures all other ranks stop, by using a torch.distributed.barrier() that will release once the debugged rank issues a continue
Reroutes stdin from the child process such that it connects to your terminal.
To use it, simply issue torch.distributed.breakpoint(rank) on all ranks, using the same value for rank in each case.
As of v1.10, torch.distributed.monitored_barrier() exists as an alternative to torch.distributed.barrier(). If not all ranks call into torch.distributed.monitored_barrier() within the provided timeout, it fails with helpful information about which rank may be faulty. torch.distributed.monitored_barrier() implements a host-side barrier using send/recv communication primitives in a process similar to acknowledgements, allowing rank 0 to report which rank(s) failed to acknowledge the barrier in time. As an example, consider an application in which rank 1 fails to call into torch.distributed.monitored_barrier() (in practice this could be due to an application bug or a hang in a previous collective).
In that case, rank 0 produces an error message that identifies the rank(s) that failed to acknowledge the barrier, allowing the user to determine which rank(s) may be faulty and investigate further.
With TORCH_CPP_LOG_LEVEL=INFO, the environment variable TORCH_DISTRIBUTED_DEBUG can be used to trigger additional useful logging and collective synchronization checks to ensure all ranks are synchronized appropriately. TORCH_DISTRIBUTED_DEBUG can be set to either OFF (default), INFO, or DETAIL depending on the debugging level required. Please note that the most verbose option, DETAIL, may impact the application performance and thus should only be used when debugging issues.
Setting TORCH_DISTRIBUTED_DEBUG=INFO will result in additional debug logging when models trained with torch.nn.parallel.DistributedDataParallel() are initialized, and TORCH_DISTRIBUTED_DEBUG=DETAIL will additionally log runtime performance statistics for a select number of iterations. These runtime statistics include data such as forward time, backward time, gradient communication time, etc.
The following logs are rendered at initialization time:
The following logs are rendered during runtime (when TORCH_DISTRIBUTED_DEBUG=DETAIL is set):
In addition, TORCH_DISTRIBUTED_DEBUG=INFO enhances crash logging in torch.nn.parallel.DistributedDataParallel() due to unused parameters in the model. Currently, find_unused_parameters=True must be passed into torch.nn.parallel.DistributedDataParallel() initialization if there are parameters that may be unused in the forward pass, and as of v1.10, all model outputs are required to be used in loss computation, as torch.nn.parallel.DistributedDataParallel() does not support unused parameters in the backward pass. These constraints are challenging, especially for larger models, so when crashing with an error, torch.nn.parallel.DistributedDataParallel() will log the fully qualified name of all parameters that went unused. For example, in an application that trains a model named TwoLinLayerNet, if the loss is instead computed as loss = output[1], then TwoLinLayerNet.a does not receive a gradient in the backward pass, which results in DDP failing. On a crash, the user is given information about the parameters that went unused, which may be challenging to find manually in large models.
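A sketch of the find_unused_parameters=True workaround for the situation described above. It assumes TwoLinLayerNet has two bias-free linear submodules, a and b, whose outputs are returned as a tuple (the layer sizes are arbitrary), and it runs DDP in a single-process gloo group (world_size=1) so no GPUs are needed:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


class TwoLinLayerNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(10, 10, bias=False)
        self.b = torch.nn.Linear(10, 1, bias=False)

    def forward(self, x):
        return self.a(x), self.b(x)


dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)

# find_unused_parameters=True lets DDP mark the parameters of self.a as ready
# even though the loss below never uses the first output.
model = DDP(TwoLinLayerNet(), find_unused_parameters=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

losses = []
for _ in range(3):
    output = model(torch.randn(20, 10))
    loss = output[1].sum()  # only TwoLinLayerNet.b contributes to the loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

dist.destroy_process_group()
```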
Setting TORCH_DISTRIBUTED_DEBUG=DETAIL will trigger additional consistency and synchronization checks on every collective call issued by the user either directly or indirectly (such as DDP allreduce). This is done by creating a wrapper process group that wraps all process groups returned by torch.distributed.init_process_group() and torch.distributed.new_group() APIs. As a result, these APIs will return a wrapper process group that can be used exactly like a regular process group, but performs consistency checks before dispatching the collective to an underlying process group. Currently, these checks include a torch.distributed.monitored_barrier(), which ensures all ranks complete their outstanding collective calls and reports ranks which are stuck. Next, the collective itself is checked for consistency by ensuring all collective functions match and are called with consistent tensor shapes. If this is not the case, a detailed error report is included when the application crashes, rather than a hang or uninformative error message. As an example, consider an application in which the ranks pass tensors of mismatched shapes into torch.distributed.all_reduce().
With the NCCL backend, such an application would likely result in a hang, which can be challenging to root-cause in nontrivial scenarios. If the user enables TORCH_DISTRIBUTED_DEBUG=DETAIL and reruns the application, the resulting error message reveals the root cause.
For fine-grained control of the debug level during runtime the functions torch.distributed.set_debug_level(), torch.distributed.set_debug_level_from_env(), and torch.distributed.get_debug_level() can also be used.
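A short sketch of the runtime debug-level functions, assuming they are exposed as torch.distributed.DebugLevel, get_debug_level(), set_debug_level() and set_debug_level_from_env(). It restores the previous level and environment afterwards so later code is unaffected:

```python
import os

import torch.distributed as dist

previous_level = dist.get_debug_level()
previous_env = os.environ.get("TORCH_DISTRIBUTED_DEBUG")

# Change the level programmatically ...
dist.set_debug_level(dist.DebugLevel.INFO)
level_after_set = dist.get_debug_level()

# ... or re-read it from the TORCH_DISTRIBUTED_DEBUG environment variable.
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
dist.set_debug_level_from_env()
level_from_env = dist.get_debug_level()

# Restore the previous environment and debug level.
if previous_env is None:
    os.environ.pop("TORCH_DISTRIBUTED_DEBUG", None)
else:
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = previous_env
dist.set_debug_level(previous_level)
print(level_after_set, level_from_env)
```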
In addition, TORCH_DISTRIBUTED_DEBUG=DETAIL can be used in conjunction with TORCH_SHOW_CPP_STACKTRACES=1 to log the entire callstack when a collective desynchronization is detected. These collective desynchronization checks will work for all applications that use c10d collective calls backed by process groups created with the torch.distributed.init_process_group() and torch.distributed.new_group() APIs.
In addition to explicit debugging support via torch.distributed.monitored_barrier() and TORCH_DISTRIBUTED_DEBUG, the underlying C++ library of torch.distributed also outputs log messages at various levels. These messages can be helpful to understand the execution state of a distributed training job and to troubleshoot problems such as network connection failures. The log level can be adjusted via the combination of the TORCH_CPP_LOG_LEVEL and TORCH_DISTRIBUTED_DEBUG environment variables.
Distributed components raise custom Exception types derived from RuntimeError:
torch.distributed.DistError: This is the base type of all distributed exceptions.
torch.distributed.DistBackendError: This exception is thrown when a backend-specific error occurs. For example, if the NCCL backend is used and the user attempts to use a GPU that is not available to the NCCL library.
torch.distributed.DistNetworkError: This exception is thrown when networking libraries encounter errors (ex: Connection reset by peer)
torch.distributed.DistStoreError: This exception is thrown when the Store encounters an error (ex: TCPStore timeout)
torch.distributed.DistError: Exception raised when an error occurs in the distributed library.
torch.distributed.DistBackendError: Exception raised when a backend error occurs in distributed.
torch.distributed.DistNetworkError: Exception raised when a network error occurs in distributed.
torch.distributed.DistStoreError: Exception raised when an error occurs in the distributed store.
If you are running single-node training, it may be convenient to interactively breakpoint your script. torch.distributed.breakpoint() offers a convenient way to breakpoint a single rank.
Set a breakpoint, but only on a single rank. All other ranks will wait for you to be done with the breakpoint before continuing.
rank (int) – Which rank to break on. Default: 0
skip (int) – Skip the first skip calls to this breakpoint. Default: 0.
---
## DistributedDataParallel
**URL:** https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
**Contents:**
- DistributedDataParallel
Implement distributed data parallelism based on torch.distributed at module level.
This container provides data parallelism by synchronizing gradients across each model replica. The devices to synchronize across are specified by the input process_group, which is the entire world by default. Note that DistributedDataParallel does not chunk or otherwise shard the input across participating GPUs; the user is responsible for defining how to do so, for example through the use of a DistributedSampler.
See also: Basics and Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel. The same constraints on input as in torch.nn.DataParallel apply.
Creation of this class requires torch.distributed to already be initialized, by calling torch.distributed.init_process_group().
DistributedDataParallel is proven to be significantly faster than torch.nn.DataParallel for single-node multi-GPU data parallel training.
To use DistributedDataParallel on a host with N GPUs, you should spawn N processes, ensuring that each process exclusively works on a single GPU from 0 to N-1. This can be done either by setting CUDA_VISIBLE_DEVICES for every process or by calling torch.cuda.set_device(i) for GPUs,
or by calling the unified accelerator API, torch.accelerator.set_device_index(i),
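where i is from 0 to N-1. In each process, initialize the process group with torch.distributed.init_process_group(backend='nccl', world_size=N) and your chosen init_method, then construct this module with DistributedDataParallel(model, device_ids=[i], output_device=i).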
Or you can use the latest API for initialization:
In order to spawn up multiple processes per node, you can use either torch.distributed.launch or torch.multiprocessing.spawn.
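The per-process setup described above can be sketched on CPU with the gloo backend. This assumes a single process (world_size=1) and a HashStore for rendezvous, both only so the example runs anywhere; on a GPU host each of the N processes would instead call torch.cuda.set_device(i), use nccl, and pass device_ids=[i], output_device=i. train_steps is a hypothetical name:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def train_steps(rank: int = 0, world_size: int = 1, steps: int = 5):
    # HashStore only works inside one process; real jobs use env:// (set by
    # torchrun or torch.distributed.launch) or a TCPStore shared by all ranks.
    dist.init_process_group("gloo", store=dist.HashStore(), rank=rank, world_size=world_size)
    try:
        torch.manual_seed(0)  # same initial weights on every rank
        model = torch.nn.Linear(4, 1)
        ddp_model = DDP(model)  # CPU module, so device_ids stays None
        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)

        x = torch.randn(32, 4)
        y = x.sum(dim=1, keepdim=True)
        losses = []
        for _ in range(steps):
            loss = torch.nn.functional.mse_loss(ddp_model(x), y)
            optimizer.zero_grad()
            loss.backward()  # DDP all-reduces gradients during backward
            optimizer.step()
            losses.append(loss.item())
        return losses, ddp_model
    finally:
        dist.destroy_process_group()


losses, ddp_model = train_steps()
```

Note that the wrapped model's state dict keys carry the "module." prefix, which matters when loading them into an unwrapped model.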
Please refer to PyTorch Distributed Overview for a brief introduction to all features related to distributed training.
DistributedDataParallel can be used in conjunction with torch.distributed.optim.ZeroRedundancyOptimizer to reduce per-rank optimizer states memory footprint. Please refer to ZeroRedundancyOptimizer recipe for more details.
The nccl backend is currently the fastest and highly recommended backend when using GPUs. This applies to both single-node and multi-node distributed training.
This module also supports mixed-precision distributed training. This means that your model can have parameters of different types, such as a mix of fp16 and fp32, and gradient reduction on these mixed-type parameters will work correctly.
If you use torch.save on one process to checkpoint the module, and torch.load on some other processes to recover it, make sure that map_location is configured properly for every process. Without map_location, torch.load would recover the module to devices where the module was saved from.
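A sketch of the save/load pattern with map_location, assuming tensors were saved from cuda:0 and should be restored onto each rank's own GPU, or onto CPU when CUDA is unavailable. An in-memory buffer stands in for the checkpoint file, and the helper names are hypothetical:

```python
import io

import torch


def save_checkpoint(model: torch.nn.Module, f) -> None:
    # Typically only rank 0 writes the checkpoint.
    torch.save(model.state_dict(), f)


def load_checkpoint(model: torch.nn.Module, f, rank: int) -> None:
    # Remap tensors saved from cuda:0 onto this rank's GPU; without
    # map_location they would be restored onto the device they were saved from.
    if torch.cuda.is_available():
        map_location = {"cuda:0": f"cuda:{rank}"}
    else:
        map_location = "cpu"
    model.load_state_dict(torch.load(f, map_location=map_location))


source = torch.nn.Linear(3, 2)
buffer = io.BytesIO()  # stands in for a checkpoint file on shared storage
save_checkpoint(source, buffer)
buffer.seek(0)

target = torch.nn.Linear(3, 2)
load_checkpoint(target, buffer, rank=0)
```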
When a model is trained on M nodes with batch=N, the gradient will be M times smaller when compared to the same model trained on a single node with batch=M*N if the loss is summed (NOT averaged as usual) across instances in a batch (because the gradients between different nodes are averaged). You should take this into consideration when you want to obtain a mathematically equivalent training process compared to the local training counterpart. But in most cases, you can just treat a DistributedDataParallel wrapped model, a DataParallel wrapped model and an ordinary model on a single GPU as the same (E.g. using the same learning rate for equivalent batch size).
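The factor of M can be checked with a toy one-parameter model. This sketch assumes each node computes its local gradient and DDP averages those local gradients over the M nodes; exact fractions keep the comparison free of rounding error, and all names are illustrative:

```python
from fractions import Fraction


def per_sample_grad(w, x, y):
    # d/dw of 0.5 * (w * x - y) ** 2
    return (w * x - y) * x


def local_grad(w, batch, reduce_loss):
    g = sum(per_sample_grad(w, x, y) for x, y in batch)
    return g / len(batch) if reduce_loss == "mean" else g


def ddp_grad(w, node_batches, reduce_loss):
    # All-reduce the per-node gradients, then divide by the number of nodes.
    grads = [local_grad(w, batch, reduce_loss) for batch in node_batches]
    return sum(grads) / len(grads)


M, N = 2, 3
w = Fraction(1, 2)
data = [(Fraction(i), Fraction(2 * i + 1)) for i in range(1, M * N + 1)]
node_batches = [data[k * N:(k + 1) * N] for k in range(M)]

# Summed loss: the single-node gradient over M*N samples is M times larger.
summed_ratio = local_grad(w, data, "sum") / ddp_grad(w, node_batches, "sum")
# Averaged loss: both setups produce the same gradient.
mean_equal = local_grad(w, data, "mean") == ddp_grad(w, node_batches, "mean")
print(summed_ratio, mean_equal)
```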
Parameters are never broadcast between processes. The module performs an all-reduce step on gradients and assumes that they will be modified by the optimizer in all processes in the same way. Buffers (e.g. BatchNorm stats) are broadcast from the module in the rank 0 process to all other replicas in the system in every iteration.
If you are using DistributedDataParallel in conjunction with the Distributed RPC Framework, you should always use torch.distributed.autograd.backward() to compute gradients and torch.distributed.optim.DistributedOptimizer for optimizing parameters.
DistributedDataParallel currently offers limited support for gradient checkpointing with torch.utils.checkpoint(). If the checkpoint is done with use_reentrant=False (recommended), DDP will work as expected without any limitations. If, however, the checkpoint is done with use_reentrant=True (the default), DDP will work as expected when there are no unused parameters in the model and each layer is checkpointed at most once (make sure you are not passing find_unused_parameters=True to DDP). We currently do not support the case where a layer is checkpointed multiple times, or when there are unused parameters in the checkpointed model.
To let a non-DDP model load a state dict from a DDP model, consume_prefix_in_state_dict_if_present() needs to be applied to strip the prefix “module.” in the DDP state dict before loading.
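A sketch of stripping the prefix, assuming consume_prefix_in_state_dict_if_present is imported from torch.nn.modules.utils. The DDP state dict is emulated by adding the "module." prefix by hand, so no process group is needed:

```python
import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

model = torch.nn.Linear(4, 2)

# A DDP-wrapped model stores its keys as "module.<name>"; emulate that here.
ddp_state_dict = {"module." + k: v for k, v in model.state_dict().items()}

# Strip the prefix in place, then load into a plain (non-DDP) module.
consume_prefix_in_state_dict_if_present(ddp_state_dict, "module.")
plain_model = torch.nn.Linear(4, 2)
plain_model.load_state_dict(ddp_state_dict)
```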
The constructor, the forward method, and differentiation of the output (or a function of the output of this module) are distributed synchronization points. Take that into account in case different processes might be executing different code.
This module assumes all parameters are registered in the model by the time it is created. No parameters should be added or removed later. The same applies to buffers.
This module assumes that the parameters registered in the model of each distributed process are in the same order. The module itself will conduct gradient allreduce following the reverse order of the registered parameters of the model. In other words, it is the user’s responsibility to ensure that each distributed process has the exact same model and thus the exact same parameter registration order.
This module allows parameters with non-rowmajor-contiguous strides. For example, your model may contain some parameters whose torch.memory_format is torch.contiguous_format and others whose format is torch.channels_last. However, corresponding parameters in different processes must have the same strides.
This module doesn’t work with torch.autograd.grad() (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters).
If you plan on using this module with a nccl backend or a gloo backend (that uses Infiniband), together with a DataLoader that uses multiple workers, please change the multiprocessing start method to forkserver (Python 3 only) or spawn. Unfortunately Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will likely experience deadlocks if you don’t change this setting.
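One way to follow this advice is to choose the start method per DataLoader through its multiprocessing_context argument; calling torch.multiprocessing.set_start_method("forkserver") under an if __name__ == "__main__": guard is the process-wide alternative. This sketch only constructs the loader, since worker processes are started lazily when it is first iterated; the dataset is a placeholder:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(16, dtype=torch.float32))

# "spawn" (or "forkserver" on POSIX) avoids forking a process that already
# holds NCCL or InfiniBand-backed Gloo state.
loader = DataLoader(dataset, batch_size=4, num_workers=2,
                    multiprocessing_context="spawn")
```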
You should never try to change your model’s parameters after wrapping your model with DistributedDataParallel. When wrapping your model, the constructor of DistributedDataParallel registers the additional gradient reduction functions on all the parameters of the model at the time of construction. If you change the model’s parameters afterwards, the gradient reduction functions no longer match the correct set of parameters.
Using DistributedDataParallel in conjunction with the Distributed RPC Framework is experimental and subject to change.
module (Module) – module to be parallelized
device_ids (list of int or torch.device) – CUDA devices. 1) For single-device modules, device_ids can contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively, device_ids can also be None. 2) For multi-device modules and CPU modules, device_ids must be None. When device_ids is None for both cases, both the input data for the forward pass and the actual module must be placed on the correct device. (default: None)
output_device (int or torch.device) – Device location of output for single-device CUDA modules. For multi-device modules and CPU modules, it must be None, and the module itself dictates the output location. (default: device_ids[0] for single-device modules)
broadcast_buffers (bool) – Flag that enables syncing (broadcasting) buffers of the module at beginning of the forward function. (default: True)
init_sync (bool) – Whether to sync during initialization to verify param shapes and broadcast parameters and buffers. WARNING: if this is set to False the user is required to ensure themselves that the weights are the same on all ranks. (default: True)

process_group – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by torch.distributed.init_process_group(), will be used. (default: None)

bucket_cap_mb – DistributedDataParallel will bucket parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in mebibytes (MiB). If None, a default size of 25 MiB is used. (default: None)

find_unused_parameters (bool) – Traverse the autograd graph from all tensors contained in the return value of the wrapped module’s forward function. Parameters that don’t receive gradients as part of this graph are preemptively marked as being ready to be reduced. In addition, parameters that may have been used in the wrapped module’s forward function but were not part of loss computation and thus would also not receive gradients are preemptively marked as ready to be reduced. (default: False)

check_reduction – This argument is deprecated.

gradient_as_bucket_view (bool) – When set to True, gradients will be views pointing to different offsets of the allreduce communication buckets. This can reduce peak memory usage, saving memory equal to the total size of the gradients. It also avoids the overhead of copying between gradients and the allreduce communication buckets. When gradients are views, detach_() cannot be called on them. If you hit such errors, refer to how the zero_grad() function in torch/optim/optimizer.py handles this case. Note that gradients become views only after the first iteration, so peak memory savings should be checked after the first iteration.
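
A minimal sketch of zeroing gradients in place without calling detach_() on plain gradient tensors, modeled on the non-set_to_none branch of the optimizer's zero_grad(). The helper name zero_grads_in_place is hypothetical; in recent PyTorch releases, optimizer.zero_grad() sets gradients to None by default, which also sidesteps the problem.

```python
import torch


def zero_grads_in_place(params):
    """Zero gradients in place so that views into a shared buffer stay valid.

    Hypothetical helper mirroring the set_to_none=False path of
    torch.optim.Optimizer.zero_grad().
    """
    for p in params:
        if p.grad is None:
            continue
        if p.grad.grad_fn is not None:
            # Gradient is part of an autograd graph: detach it first.
            p.grad.detach_()
        else:
            # Plain tensor (possibly a bucket view): only stop tracking grads.
            p.grad.requires_grad_(False)
        p.grad.zero_()


# Usage on an ordinary model (no DDP needed to try it):
model = torch.nn.Linear(3, 1)
model(torch.ones(2, 3)).sum().backward()
zero_grads_in_place(model.parameters())
```

Because zero_() writes through the view, a gradient that aliases a larger buffer (as with gradient_as_bucket_view=True) keeps pointing at that buffer after zeroing.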

static_graph (bool) – When set to True, DDP knows the trained graph is static. A static graph means that 1) the set of used and unused parameters will not change during the whole training loop, in which case it does not matter whether users set find_unused_parameters = True or not; and 2) how the graph is trained will not change during the whole training loop (meaning there is no control flow depending on iterations). When static_graph is set to True, DDP supports cases that could not be supported in the past: 1) reentrant backwards; 2) activation checkpointing multiple times; 3) activation checkpointing when the model has unused parameters; 4) model parameters that are outside of the forward function; 5) potentially improved performance when there are unused parameters, as DDP will not search the graph in each iteration to detect unused parameters. To check whether you can set static_graph to True, one way is to check the DDP logging data at the end of your previous model training: if ddp_logging_data.get("can_set_static_graph") == True, you can most likely set static_graph = True as well. Example:

```python
import torch

# Assumes the default process group is initialized and `model` is defined.
model_DDP = torch.nn.parallel.DistributedDataParallel(model)
# Training loop
...
ddp_logging_data = model_DDP._get_ddp_logging_data()
static_graph = ddp_logging_data.get("can_set_static_graph")
```

delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter) – a list of named parameters whose all reduce will be delayed when the gradient of the parameter specified in param_to_hook_all_reduce is ready. Other arguments of DDP do not apply to named params specified in this argument as these named params will be ignored by DDP reducer.

param_to_hook_all_reduce (torch.nn.Parameter) – a parameter to hook delayed all reduce of parameters specified in delay_all_reduce_named_params.

skip_all_reduce_unused_params – When set to True, DDP will skip reducing unused parameters. This requires that the set of unused parameters remain the same across all ranks throughout the entire training process. If this condition is not met, it may cause desynchronization and result in a training hang.

module (Module) – the module to be parallelized.

Context manager for training with uneven inputs across processes in DDP.

This context manager will keep track of already-joined DDP processes, and “shadow” the forward and backward passes by inserting collective communication operations to match with the ones created by non-joined DDP processes. This will ensure each collective call has a corresponding call by already-joined DDP processes, preventing hangs or errors that would otherwise happen when training with uneven inputs across processes. Alternatively, if the flag throw_on_early_termination is specified to be True, all trainers will throw an error once one rank runs out of inputs, allowing these errors to be caught and handled according to application logic.

Once all DDP processes have joined, the context manager will broadcast the model corresponding to the last joined process to all processes to ensure the model is the same across all processes (which is guaranteed by DDP).

To enable training with uneven inputs across processes, simply wrap this context manager around your training loop. No further modifications to the model or data loading are required.
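
A minimal sketch of wrapping one epoch in the join() context manager. The function name train_with_uneven_inputs and its arguments are illustrative assumptions; ddp_model is assumed to be a DistributedDataParallel instance, and loader may yield a different number of batches on each rank.

```python
def train_with_uneven_inputs(ddp_model, loader, optimizer, loss_fn,
                             throw_on_early_termination=False):
    """Run one epoch in which ranks may see different numbers of batches.

    Hypothetical helper: ddp_model is assumed to be a DistributedDataParallel
    module and loader yields (inputs, targets) pairs. Returns the local step count.
    """
    steps = 0
    # join() shadows the collectives of ranks that still have data, so a rank
    # that runs out of batches early does not leave the others hanging.
    with ddp_model.join(throw_on_early_termination=throw_on_early_termination):
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = loss_fn(ddp_model(inputs), targets)
            loss.backward()
            optimizer.step()
            steps += 1
    return steps
```

The training loop itself is unchanged; only the enclosing with statement is added.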

If the model or training loop this context manager is wrapped around has additional distributed collective operations, such as SyncBatchNorm in the model’s forward pass, then the flag throw_on_early_termination must be enabled. This is because this context manager is not aware of non-DDP collective communication. This flag will cause all ranks to throw when any one rank exhausts inputs, allowing these errors to be caught and recovered from across all ranks.

divide_by_initial_world_size (bool) – If True, will divide gradients by the initial world_size DDP training was launched with. If False, will compute the effective world size (number of ranks that have not depleted their inputs yet) and divide gradients by that during allreduce. Set divide_by_initial_world_size=True to ensure every input sample, including the uneven inputs, has equal weight in terms of how much it contributes to the global gradient. This is achieved by always dividing the gradient by the initial world_size even when we encounter uneven inputs. If you set this to False, we divide the gradient by the remaining number of ranks. This ensures parity with training on a smaller world_size, although it also means the uneven inputs would contribute more towards the global gradient. Typically, you would want to set this to True for cases where the last few inputs of your training job are uneven. In extreme cases, where there is a large discrepancy in the number of inputs, setting this to False might provide better results.

enable (bool) – Whether to enable uneven input detection or not. Pass in enable=False to disable in cases where you know that inputs are even across participating processes. Default is True.

throw_on_early_termination (bool) – Whether to throw an error or continue training when at least one rank has exhausted inputs. If True, will throw upon the first rank reaching end of data. If False, will continue training with a smaller effective world size until all ranks are joined. Note that if this flag is specified, then the flag divide_by_initial_world_size would be ignored. Default is False.
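
To make the effect of divide_by_initial_world_size concrete, the sketch below computes the averaged gradient of a single scalar parameter when some ranks have already joined, assuming each joined rank's shadowed allreduce contributes zero. The function name is hypothetical and it models only the arithmetic, not DDP itself.

```python
def averaged_gradient(active_rank_grads, initial_world_size,
                      divide_by_initial_world_size=True):
    """Average one scalar gradient across ranks under join() semantics.

    active_rank_grads: gradients from ranks that still have inputs; joined
    ranks are assumed to contribute 0 through the shadowed allreduce.
    """
    total = sum(active_rank_grads)
    if divide_by_initial_world_size:
        divisor = initial_world_size        # every sample keeps equal weight
    else:
        divisor = len(active_rank_grads)    # effective (remaining) world size
    return total / divisor


# 4 ranks launched; 1 has run out of data; the other 3 each produce grad 1.0
print(averaged_gradient([1.0, 1.0, 1.0], 4, True))   # 0.75
print(averaged_gradient([1.0, 1.0, 1.0], 4, False))  # 1.0
```

With True, each remaining sample is weighted 1/4, exactly as in a full step; with False, each is weighted 1/3, so the uneven inputs count for more.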

DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes.

kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs.

If True, then gradients are divided by the initial world size that DDP was launched with. If False, then gradients are divided by the effective world size (i.e. the number of non-joined processes), meaning that the uneven inputs contribute more toward the global gradient. Typically, this should be set to True if the degree of unevenness is small but can be set to False in extreme cases for possibly better results. Default is True.

Context manager to disable gradient synchronizations across DDP processes.

Within this context, gradients will be accumulated on module variables and later synchronized in the first forward-backward pass after exiting the context.

The forward pass should be included inside the context manager, or else gradients will still be synchronized.
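
A minimal sketch of gradient accumulation with no_sync(): allreduce is skipped for every micro-batch except the last, and both the forward and backward passes run inside the context. The helper name accumulate_then_step is hypothetical, and ddp_model is assumed to be a DistributedDataParallel instance.

```python
import contextlib


def accumulate_then_step(ddp_model, micro_batches, optimizer, loss_fn):
    """Accumulate gradients locally and synchronize only on the last micro-batch.

    Hypothetical helper: ddp_model is assumed to be a DistributedDataParallel
    module; micro_batches is a list of (inputs, targets) pairs.
    """
    optimizer.zero_grad()
    num_micro = len(micro_batches)
    for i, (inputs, targets) in enumerate(micro_batches):
        is_last = i == num_micro - 1
        # Forward AND backward must be inside the context, or gradients sync anyway.
        ctx = contextlib.nullcontext() if is_last else ddp_model.no_sync()
        with ctx:
            # Scale so the accumulated gradient is an average over micro-batches.
            loss = loss_fn(ddp_model(inputs), targets) / num_micro
            loss.backward()
    optimizer.step()
```

The last micro-batch runs outside no_sync(), so its backward pass triggers the allreduce of everything accumulated so far.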

Register communication hook for user-defined DDP aggregation of gradients across multiple workers.

This hook is useful for researchers trying out new ideas. For example, it can be used to implement algorithms such as GossipGrad and gradient compression, which involve different communication strategies for parameter syncs while running DistributedDataParallel training.

state (object) – Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker.

hook (Callable) – Callable with the following signature: hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]. This function is called once the bucket is ready. The hook can perform whatever processing is needed and return a Future indicating completion of any async work (e.g., allreduce). If the hook doesn’t perform any communication, it must still return a completed Future. The Future should hold the new value of the grad bucket’s tensors. Once a bucket is ready, the c10d reducer calls this hook, uses the tensors returned by the Future, and copies the gradients to the individual parameters. Note that the Future’s return type must be a single tensor. We also provide an API called get_future to retrieve a Future associated with the completion of c10d.ProcessGroup.Work. get_future is currently supported for NCCL and also supported for most operations on GLOO and MPI, except for peer-to-peer operations (send/recv).

The grad bucket’s tensors will not be predivided by world_size. The user is responsible for dividing by world_size in the case of operations like allreduce.

DDP communication hook can only be registered once and should be registered before calling backward.

The Future object that the hook returns should contain a single tensor that has the same shape as the tensors inside the grad bucket.

get_future API supports NCCL, and partially GLOO and MPI backends (no support for peer-to-peer operations like send/recv) and will return a torch.futures.Future.

Below is an example of a noop hook that returns the same tensor.
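
A sketch of such a no-op hook, following the signature described above: it performs no communication and wraps the bucket's flattened buffer in an already-completed Future. Registration is shown only as a comment because it needs an initialized process group and a DDP model (ddp_model is an assumed name).

```python
import torch
import torch.distributed as dist


def noop(state: object, bucket: "dist.GradBucket") -> "torch.futures.Future[torch.Tensor]":
    # No communication: return a completed Future holding the bucket's buffer.
    fut = torch.futures.Future()
    fut.set_result(bucket.buffer())
    return fut


# ddp_model.register_comm_hook(state=None, hook=noop)
```

Since gradients are never exchanged, each rank trains on its local gradients only; this is useful for measurement, not for real training.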

Below is an example of a Parallel SGD algorithm where gradients are encoded before allreduce, and then decoded after allreduce.
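
A sketch of that encode, allreduce, decode pattern, assuming a simple fp16 cast as the "encoding". This is a stand-in chosen for illustration; encode and decode are hypothetical helpers. Because the grad bucket is not predivided by world_size, the sketch divides before the allreduce, and state is treated as the process group (None meaning the default group).

```python
import torch
import torch.distributed as dist


def encode(grad: torch.Tensor, world_size: int) -> torch.Tensor:
    # Hypothetical encoder: pre-divide (buckets are not predivided) and cast to fp16.
    return grad.div(world_size).to(torch.float16)


def decode(encoded: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical decoder: cast back to the bucket's original dtype.
    return encoded.to(dtype)


def encode_and_decode(state: object, bucket: "dist.GradBucket") -> "torch.futures.Future[torch.Tensor]":
    group = state  # a ProcessGroup, or None for the default process group
    buffer = bucket.buffer()
    encoded = encode(buffer, dist.get_world_size(group))
    fut = dist.all_reduce(encoded, group=group, async_op=True).get_future()

    def _decode(fut: torch.futures.Future) -> torch.Tensor:
        # For allreduce, fut.value() is a list holding the reduced tensor.
        return decode(fut.value()[0], buffer.dtype)

    # The chained Future resolves to a single tensor with the bucket's shape.
    return fut.then(_decode)


# ddp_model.register_comm_hook(state=None, hook=encode_and_decode)
```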

---

## DDP Communication Hooks

**URL:** https://pytorch.org/docs/stable/ddp_comm_hooks.html

**Contents:**
- DDP Communication Hooks
- How to Use a Communication Hook?
- What Does a Communication Hook Operate On?
- Default Communication Hooks
- PowerSGD Communication Hook
- PowerSGD State
- PowerSGD Hooks
- Debugging Communication Hooks
- Checkpointing of Communication Hooks
- Acknowledgements

Created On: Jun 06, 2025 | Last Updated On: Jun 06, 2025

DDP communication hook is a generic interface to control how gradients are communicated across workers by overriding the vanilla allreduce in DistributedDataParallel. A few built-in communication hooks are provided, and users can easily apply any of these hooks to optimize communication. In addition, the hook interface can support user-defined communication strategies for more advanced use cases.

To use a communication hook, the user only needs to register the hook with the DDP model before the training loop, using the following method:

torch.nn.parallel.DistributedDataParallel.register_comm_hook()
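
A minimal sketch of registering one of the built-in hooks from default_hooks. The helper name register_default_hook and its use_fp16 switch are illustrative, and ddp_model is assumed to be an existing DistributedDataParallel instance.

```python
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks


def register_default_hook(ddp_model, process_group=None, use_fp16=False):
    """Register a built-in hook on ddp_model before the training loop.

    For the stateless default hooks, state is a process group or None
    (None selects the default process group).
    """
    hook = default_hooks.fp16_compress_hook if use_fp16 else default_hooks.allreduce_hook
    ddp_model.register_comm_hook(state=process_group, hook=hook)
    return hook
```

Because a hook can only be registered once per DDP model, pick the hook before the first backward pass.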

A communication hook provides a flexible way to allreduce gradients. Therefore, it mainly operates on the gradients on each replica before allreduce, which are bucketized to increase the overlap between communication and computation. Particularly, torch.distributed.GradBucket represents a bucket of gradient tensors to be allreduced.

This class mainly passes a flattened gradient tensor (returned by buffer()) to DDP communication hook. This tensor can be further decomposed into a list of per-parameter tensors within this bucket (returned by get_per_parameter_tensors()) to apply layer-wise operations.

Since the buckets are rebuilt after the first iteration, users should not rely on the bucket indices at the beginning of training.

The index of a bucket that stores gradients of a few contiguous layers. All the gradients are bucketized.

A flattened 1D torch.Tensor buffer, which can be further decomposed into a list of per-parameter tensors within this bucket.

A list of torch.Tensor. Each tensor in the list corresponds to a gradient.

Whether this bucket is the last bucket to allreduce in an iteration. This also means that this bucket corresponds to the first few layers in the forward pass.

Replaces the tensor in the bucket with the input tensor buffer.

A list of torch.Tensor. Each tensor in the list corresponds to a model parameter.
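
A sketch of a hook that inspects a GradBucket before delegating to the built-in allreduce hook. It uses the GradBucket accessors index(), is_last(), buffer(), gradients(), and parameters(), which correspond to the descriptions above. The names bucket_summary and logging_allreduce_hook are hypothetical, and printing for every bucket is only reasonable while debugging.

```python
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks


def bucket_summary(bucket) -> dict:
    """Collect a few facts about one GradBucket, including per-gradient norms."""
    grads = bucket.gradients()  # list of per-parameter gradient tensors
    return {
        "index": bucket.index(),
        "is_last": bucket.is_last(),
        "numel": bucket.buffer().numel(),
        "num_params": len(bucket.parameters()),
        "grad_norms": [g.norm().item() for g in grads],
    }


def logging_allreduce_hook(state, bucket):
    # Bucket indices are only stable after the first iteration (buckets are rebuilt).
    print(bucket_summary(bucket))
    return default_hooks.allreduce_hook(state, bucket)
```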

Default communication hooks are simple stateless hooks, so the input state in register_comm_hook is either a process group or None. The input bucket is a torch.distributed.GradBucket object.

Call allreduce using GradBucket tensors.

Once the gradient tensors are aggregated across all workers, the hook’s then callback takes the mean and returns the result.

If the user registers this DDP communication hook, DDP results are expected to be the same as when no hook is registered. Hence, this won’t change the behavior of DDP, and users can use it as a reference or modify it to log useful information or for any other purpose without affecting DDP behavior.

Compress by casting GradBucket to torch.float16 divided by process group size.

This DDP communication hook implements a simple gradient compression approach that casts the GradBucket tensor to half-precision floating-point format (torch.float16) and then divides it by the process group size. It allreduces those float16 gradient tensors. Once the compressed gradient tensors are allreduced, the chained decompress callback casts them back to the input data type (such as float32).

Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

This DDP communication hook implements a simple gradient compression approach that casts the GradBucket tensor to half-precision Brain floating point format (torch.bfloat16) and then divides it by the process group size. It allreduces those bfloat16 gradient tensors. Once the compressed gradient tensors are allreduced, the chained decompress callback casts them back to the input data type (such as float32).

Additionally, communication hook wrappers are provided that apply the same compression as fp16_compress_hook() or bf16_compress_hook() and can be combined with other communication hooks.

Cast input tensor to torch.float16, cast result of hook back to input dtype.

This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision floating point format (torch.float16), and casts the resulting tensor of the given hook back to the input data type, such as float32. Therefore, fp16_compress_hook is equivalent to fp16_compress_wrapper(allreduce_hook).

Callable[[Any, GradBucket], Future[Tensor]]

Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision Brain floating point format (torch.bfloat16), and casts the resulting tensor of the given hook back to the input data type, such as float32. Therefore, bf16_compress_hook is equivalent to bf16_compress_wrapper(allreduce_hook).

Callable[[Any, GradBucket], Future[Tensor]]

PowerSGD (Vogels et al., NeurIPS 2019) is a gradient compression algorithm, which can provide very high compression rates and accelerate bandwidth-bound distributed training. This algorithm needs to maintain both some hyperparameters and the internal state. Therefore, PowerSGD communication hook is a stateful hook, and the user needs to provide a state object defined as below.

Store both the algorithm’s hyperparameters and internal state for all gradients during training.

In particular, matrix_approximation_rank and start_powerSGD_iter are the main hyperparameters that should be tuned by the user. For performance, we suggest keeping the binary hyperparameters use_error_feedback and warm_start enabled.

matrix_approximation_rank controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.

1.1. If matrix_approximation_rank is too low, the model may need more training steps to reach full quality, or may never reach it, resulting in a loss of accuracy.

1.2. The increase of matrix_approximation_rank can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain matrix_approximation_rank threshold.

To tune matrix_approximation_rank, we suggest starting from 1 and increasing by factors of 2 (like an exponential grid search: 1, 2, 4, …) until a satisfactory accuracy is reached. Typically only a small value of 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.

start_powerSGD_iter defers PowerSGD compression until step start_powerSGD_iter, and vanilla allreduce runs prior to that step. This hybrid scheme of vanilla allreduce + PowerSGD can effectively improve accuracy, even when a relatively small matrix_approximation_rank is used. This is because the beginning of training is usually very sensitive to inaccurate gradients, and compressing gradients too early may quickly put training on a suboptimal trajectory, which can have an irrecoverable impact on accuracy.

To tune start_powerSGD_iter, we suggest starting with 10% of the total training steps and increasing it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, start_powerSGD_iter should typically be no less than the number of warm-up steps.

min_compression_rate is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where (num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols. If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression.
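
The bandwidth criterion above can be checked directly for a tensor viewed as a num_rows x num_cols matrix. The function name worth_compressing is hypothetical; the default of 2 for min_compression_rate is an assumption here, chosen to match PowerSGDState's default.

```python
def worth_compressing(num_rows, num_cols, matrix_approximation_rank,
                      min_compression_rate=2):
    """Return True if PowerSGD's bandwidth criterion favors compressing this tensor.

    P holds num_rows x rank numbers and Q holds num_cols x rank numbers, so the
    compressed payload is (num_rows + num_cols) * rank numbers.
    """
    compressed = (num_rows + num_cols) * matrix_approximation_rank
    uncompressed = num_rows * num_cols
    return compressed * min_compression_rate < uncompressed


print(worth_compressing(1024, 1024, 1))  # True: 2048 * 2 < 1048576
print(worth_compressing(1024, 1, 1))     # False: a bias-like vector
print(worth_compressing(64, 64, 32))     # False: rank too high for a 64 x 64 matrix
```

This is why vector tensors such as biases always fall back to plain allreduce: with num_cols = 1, the compressed payload is never smaller than the original.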

Compression statistics are logged every compression_stats_logging_frequency iterations once PowerSGD compression starts.

orthogonalization_epsilon can be a very small value (e.g., 1e-8) added to every normalized matrix column in the orthogonalization step, to prevent a division-by-zero error if any column is all zeros. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy.

batch_tensors_with_same_shape controls whether to compress and decompress tensors with same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., bucket_cap_mb arg in DDP constructor) to make more same-shaped tensors appear in the same bucket, however this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to True if the compression / decompression computation is a bottleneck.

If error feedback or warm-up is enabled, the minimum value of start_powerSGD_iter allowed in DDP is 2. This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, and this can conflict with any tensor memorized before the rebuild process.

PowerSGD typically requires extra memory of the same size as the model’s gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy.

PowerSGD hooks may conflict with Apex automatic mixed precision package. Please use PyTorch native automatic mixed precision package instead.
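
A sketch of building a PowerSGDState that follows the tuning advice above. Compression starts after roughly 10% of training, never before warm-up ends, and never before iteration 2, which is the minimum when error feedback or warm start is on. make_powersgd_state and its arguments are hypothetical. PowerSGDState and powerSGD_hook are imported from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook, and registration is left as a comment because it needs a DDP model.

```python
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD


def make_powersgd_state(total_steps, warmup_steps=0, rank=1, process_group=None):
    # ~10% of training with vanilla allreduce, not before warm-up ends, and at
    # least 2 because error feedback and warm start are enabled below.
    start_iter = max(2, warmup_steps, total_steps // 10)
    return powerSGD.PowerSGDState(
        process_group=process_group,
        matrix_approximation_rank=rank,
        start_powerSGD_iter=start_iter,
        min_compression_rate=2,
        use_error_feedback=True,
        warm_start=True,
    )


# state = make_powersgd_state(total_steps=100_000, warmup_steps=2_000)
# ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)
```

Following the advice on matrix_approximation_rank, a reasonable tuning loop would rerun with rank=1, 2, 4, … and keep the smallest rank that reaches acceptable accuracy.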

Implement PowerSGD algorithm.

This DDP communication hook implements PowerSGD gradient compression algorithm described in the paper. Once gradient tensors are aggregated across all workers, this hook applies compression as follows:

1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups:

1.1. The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth.

1.2. The rest of the tensors, including all the vector tensors (for biases), which will be directly allreduced without compression.

2. Handles uncompressed tensors:

2.1. Allocates contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression;

2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor.

3. Handles the tensors that should be compressed by PowerSGD compression:

3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

3.2. Computes each P in Ps, which is equal to MQ;

3.3. Allreduces Ps as a batch;

3.4. Orthogonalizes each P in Ps;

3.5. Computes each Q in Qs, which is approximately equal to M^TP;

3.6. Allreduces Qs as a batch;

3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T.
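
Steps 3.1 to 3.7 can be sketched for a single matrix on one process, with the two allreduces (3.3 and 3.6) marked as comments since there is nothing to average locally. This is an illustrative sketch, not the library implementation. It omits error feedback and warm start, and it uses torch.linalg.qr for orthogonalization, which is an assumption: the library's own orthogonalization routine may differ. The name powersgd_round is hypothetical.

```python
import torch


def powersgd_round(M: torch.Tensor, rank: int, generator=None):
    """One PowerSGD compress/decompress round for a 2D tensor M (single process)."""
    num_rows, num_cols = M.shape
    # 3.1: Q is drawn from a standard normal distribution and orthogonalized.
    Q = torch.randn(num_cols, rank, generator=generator, dtype=M.dtype)
    Q, _ = torch.linalg.qr(Q)
    # 3.2: P = MQ
    P = M @ Q
    # 3.3: allreduce P across workers (skipped in this single-process sketch)
    # 3.4: orthogonalize P
    P, _ = torch.linalg.qr(P)
    # 3.5: Q = M^T P
    Q = M.T @ P
    # 3.6: allreduce Q across workers (skipped)
    # 3.7: decompress, M is approximately P Q^T
    return P, Q, P @ Q.T


g = torch.Generator().manual_seed(0)
M = torch.randn(8, 1, generator=g) @ torch.randn(1, 6, generator=g)  # exactly rank 1
P, Q, M_hat = powersgd_round(M, rank=1, generator=g)
```

Only P (8 x 1) and Q (6 x 1), 14 numbers in total, would be communicated instead of the 48 entries of M. For an exactly rank-1 matrix, as in the usage lines, one round recovers M up to floating-point error.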

Note that this communication hook enforces vanilla allreduce for the first state.start_powerSGD_iter iterations. This not only gives the user more control over the tradeoff between speedup and accuracy, but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

state (PowerSGDState) – State information to configure the compression rate and support error feedback, warm start, etc. To tune the compression configs, mainly tune matrix_approximation_rank, start_powerSGD_iter and min_compression_rate.

bucket (dist.GradBucket) – Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode, only exactly one tensor is stored in this bucket.

Future handler of the communication, which updates the gradients in place.

Implement simplified PowerSGD algorithm.

This DDP communication hook implements a simplified PowerSGD gradient compression algorithm described in the paper. This variant does not compress the gradients layer by layer, but instead compresses the flattened input tensor that batches all the gradients. Therefore, it is faster than powerSGD_hook(), but usually results in a much lower accuracy, unless matrix_approximation_rank is 1.

Increasing matrix_approximation_rank here may not necessarily increase the accuracy, because batching per-parameter tensors without column/row alignment can destroy low-rank structure. Therefore, the user should always consider powerSGD_hook() first, and only consider this variant when a satisfactory accuracy can be achieved when matrix_approximation_rank is 1.

Once gradient tensors are aggregated across all workers, this hook applies compression as follows:

Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;

Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

Computes P, which is equal to MQ;

Computes Q, which is approximately equal to M^TP;

Computes M, which is approximately equal to PQ^T.

Truncates the input tensor to the original length.
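
The first and last steps of the simplified variant, padding the flattened bucket into a square matrix and truncating back afterwards, can be sketched as follows. to_square and from_square are hypothetical helper names. The low-rank steps in between would be the same power iteration as in the per-tensor hook.

```python
import math

import torch


def to_square(flat: torch.Tensor):
    """Zero-pad a 1D tensor of length n into a side x side matrix, side = ceil(sqrt(n))."""
    n = flat.numel()
    side = math.ceil(math.sqrt(n))
    padded = torch.zeros(side * side, dtype=flat.dtype, device=flat.device)
    padded[:n] = flat
    return padded.view(side, side), n


def from_square(M: torch.Tensor, n: int) -> torch.Tensor:
    """Flatten the (approximated) square matrix and truncate to the original length."""
    return M.reshape(-1)[:n]


flat = torch.arange(10, dtype=torch.float32)
M, n = to_square(flat)  # a 4 x 4 matrix: 10 values followed by 6 zeros
```

Because the whole bucket is reshaped without regard to parameter boundaries, rows of M mix values from different parameters. This is one way to see why this variant can lose the low-rank structure that the per-tensor hook exploits.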

Note that this communication hook enforces vanilla allreduce for the first state.start_powerSGD_iter iterations. This not only gives the user more control over the tradeoff between speedup and accuracy, but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

state (PowerSGDState) – State information to configure the compression rate and support error feedback, warm start, etc. To tune the compression configs, mainly tune matrix_approximation_rank and start_powerSGD_iter.

bucket (dist.GradBucket) – Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode, only exactly one tensor is stored in this bucket.

Future handler of the communication, which updates the gradients in place.

As the name implies, debugging communication hooks are only used for debugging and performance optimization purposes.

Debugging communication hooks do not necessarily output the correct results.

Return a future that wraps the input, so it is a no-op that does not incur any communication overheads.

This hook should only be used for headroom analysis of allreduce optimization, not for normal gradient synchronization. For example, if less than a 10% speedup in training time is observed after this hook is registered, it usually implies that allreduce is not a performance bottleneck in this case. Such instrumentation can be particularly useful if GPU traces cannot be easily retrieved or if trace analysis is complicated by factors such as the overlap between allreduce and computation or desynchronization across ranks.
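
The headroom analysis above reduces to comparing step times from two runs, one with normal gradient synchronization and one with the no-op hook registered. A separate run is needed because a communication hook can only be registered once. The function name allreduce_headroom is hypothetical. The registration lines are comments and assume the hook is exposed as debugging_hooks.noop_hook.

```python
def allreduce_headroom(step_time_baseline: float, step_time_noop: float) -> float:
    """Fraction of step time saved when gradient communication is skipped entirely."""
    return 1.0 - step_time_noop / step_time_baseline


# Run 1: measure the average step time with the default allreduce.
# Run 2 (fresh DDP model), then measure again:
#   from torch.distributed.algorithms.ddp_comm_hooks import debugging_hooks
#   ddp_model.register_comm_hook(state=None, hook=debugging_hooks.noop_hook)
print(allreduce_headroom(0.200, 0.190))  # about 0.05, i.e. below the 10% rule of thumb
```

A result under roughly 0.10 suggests that optimizing allreduce (for example, with compression hooks) is unlikely to pay off for that workload.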

A stateful communication hook can be saved as a part of model checkpointing to enable trainer restarts. To make a hook serializable, __setstate__ and __getstate__ should be defined.

__getstate__ should exclude non-serializable attributes from a returned dictionary.

__setstate__ should properly initialize non-serializable attributes, excluded from a provided state.

PowerSGDState has __setstate__ and __getstate__ implemented and can be used as a reference.
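
A standard-library sketch of the __getstate__/__setstate__ pattern for a hypothetical stateful hook. A threading.Lock stands in for a non-serializable attribute, and process_group is excluded as the text describes. The class and attribute names are illustrative only. Resetting process_group to None here is a simplification; PowerSGDState itself restores it to the default group.

```python
import pickle
import threading


class CountingHookState:
    """Hypothetical state for a custom comm hook that counts processed buckets."""

    def __init__(self, process_group=None):
        self.process_group = process_group  # not serializable in real use
        self.lock = threading.Lock()        # stand-in non-serializable attribute
        self.buckets_seen = 0

    def __getstate__(self):
        # Exclude non-serializable attributes from the returned dictionary.
        state = self.__dict__.copy()
        state.pop("process_group", None)
        state.pop("lock", None)
        return state

    def __setstate__(self, state):
        # Restore the serializable fields, then re-initialize the excluded ones.
        self.__dict__.update(state)
        self.process_group = None  # simplification: fall back to the default group
        self.lock = threading.Lock()


state = CountingHookState()
state.buckets_seen = 42
restored = pickle.loads(pickle.dumps(state))
print(restored.buckets_seen)  # 42
```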
|
|
2150
|
+
|
|
2151
|
+
Return a Dict[str, Any] which will be pickled and saved.
|
|
2152
|
+
|
|
2153
|
+
process_group is not serializable and excluded from a returned state.
|
|
2154
|
+
|
|
2155
|
+
Take a provided state and set to this PowerSGDState instance.
|
|
2156
|
+
|
|
2157
|
+
process_group is set to default.

Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook.

Many thanks to PowerSGD paper author Thijs Vogels for the code review on the PowerSGD communication hook, as well as the comparison experiments, which show that the performance of the PowerSGD communication hook is on par with the implementation in the original paper.

---

## Distributed Checkpoint - torch.distributed.checkpoint

**URL:** https://pytorch.org/docs/stable/distributed.checkpoint.html

**Contents:**

- Distributed Checkpoint - torch.distributed.checkpoint

- Additional resources:

Created On: Nov 16, 2022 | Last Updated On: Sep 04, 2025

Distributed Checkpoint (DCP) supports loading and saving models from multiple ranks in parallel. It handles load-time resharding, which enables saving in one cluster topology and loading into another.

DCP is different from torch.save and torch.load in a few significant ways:

- It produces multiple files per checkpoint, with at least one per rank.

- It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
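
To make those two differences concrete, here is a stdlib-only toy, not DCP itself: each hypothetical "rank" writes its own shard file, and loading fills buffers the caller has already allocated instead of returning new objects. The `rank{N}.bin` file names and both function names are made up for illustration.

```python
import os
import tempfile


def toy_save(shards: dict[int, bytes], ckpt_dir: str) -> None:
    # One file per "rank": each rank writes only its local shard.
    os.makedirs(ckpt_dir, exist_ok=True)
    for rank, data in shards.items():
        with open(os.path.join(ckpt_dir, f"rank{rank}.bin"), "wb") as f:
            f.write(data)


def toy_load_(buffers: dict[int, bytearray], ckpt_dir: str) -> None:
    # In-place load: the caller pre-allocates each buffer with the right size,
    # and the bytes are copied into that existing storage.
    for rank, buf in buffers.items():
        with open(os.path.join(ckpt_dir, f"rank{rank}.bin"), "rb") as f:
            n = f.readinto(buf)
        if n != len(buf):
            raise ValueError(f"rank {rank}: expected {len(buf)} bytes, read {n}")


with tempfile.TemporaryDirectory() as d:
    toy_save({0: b"\x01\x02", 1: b"\x03\x04"}, d)
    print(sorted(os.listdir(d)))  # ['rank0.bin', 'rank1.bin']
    bufs = {0: bytearray(2), 1: bytearray(2)}
    before = bufs[1]
    toy_load_(bufs, d)
    print(bytes(bufs[1]), bufs[1] is before)  # b'\x03\x04' True
```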

The entrypoints to load and save a checkpoint are the following:

Additional resources:

- Getting Started with Distributed Checkpoint (DCP)

- Asynchronous Saving with Distributed Checkpoint (DCP)

- TorchTitan Checkpointing Docs

- TorchTitan DCP Implementation

Enum for async checkpointer type.

This class contains futures for staging and upload completion. It is returned by async_save(). staging_completion is a future that indicates when the local copy of state_dict is complete. upload_completion is a future that indicates when the checkpoint has completed saving.

Save a distributed model in SPMD style.

This function is different from torch.save() as it handles ShardedTensor and DTensor by having each rank save only its local shards.

For each Stateful object (having both a state_dict and a load_state_dict), save will call state_dict before serialization.

There are no guarantees of backwards compatibility across PyTorch versions for saved state_dicts.

If using the process_group argument, make sure that only its ranks call save_state_dict and that all data in state_dict belongs to it.

When saving a checkpoint for FSDP’s ShardingStrategy.HYBRID_SHARD, only one shard_group should call save_state_dict, and the corresponding process group needs to be passed in.

If no process group is available, this function assumes the intention is to save the state_dict in the local process.

state_dict (Dict[str, Any]) – The state_dict to save.

checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)

storage_writer (Optional[StorageWriter]) – Instance of StorageWriter used to perform writes. If this is not specified, DCP will automatically infer the writer based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)

planner (Optional[SavePlanner]) – Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: None)

process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)

no_dist (bool) – If True, this function will assume the intent is to save a checkpoint on a single rank/process. (Default: False)

use_collectives (bool) – If False, this function will assume the intent is to save a checkpoint without using cross-rank synchronization. (Default: True) This configuration is experimental and should be used with caution. It will change the format of the saved checkpoint and may not be backward compatible.

Metadata object for the saved checkpoint.

save_state_dict uses collectives to coordinate writes across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Asynchronous version of save. This function first stages the state_dict onto the staging storage (CPU memory by default), and then calls save in a separate thread.

This feature is experimental and subject to change. **Close must be called after the last checkpoint is saved.**

state_dict (Dict[str, Any]) – The state_dict to save.

checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)

storage_writer (Optional[StorageWriter]) – Instance of StorageWriter used to perform ‘stage’ and ‘save’. If this is not specified, DCP will automatically infer the writer based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)

planner (Optional[SavePlanner]) – Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: None)

process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)

async_checkpointer_type (AsyncCheckpointerType) – Whether to checkpoint in a separate thread or a separate process. (Default: AsyncCheckpointerType.THREAD)

async_stager (AsyncStager) – Provides the staging implementation. If storage_writer implements AsyncStager and async_stager is provided, async_stager will be used for staging.

no_dist (bool) – If True, this function will assume the intent is to save a checkpoint on a single rank/process. (Default: False)

use_collectives (bool) – If False, save the checkpoint without rank coordination. (Default: True) This configuration is experimental and should be used with caution. It will change the format of the saved checkpoint and may not be backward compatible.

A future holding the resultant Metadata object from save.
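
The staging/upload split and the two completion futures can be modeled with `concurrent.futures` alone. This is a hypothetical stdlib toy of the flow described above, not DCP's implementation: "staging" is a deep copy taken before returning, "upload" runs on a worker thread, and the executor must be shut down at the end, mirroring the requirement to call close. `ToyAsyncCheckpointer` and `ToySaveResponse` are invented names.

```python
import copy
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class ToySaveResponse:
    staging_completion: Future  # done once the training-safe copy exists
    upload_completion: Future   # done once the copy has been "written"


class ToyAsyncCheckpointer:
    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)
        self.storage: dict[str, dict] = {}

    def _upload(self, staged: dict, checkpoint_id: str) -> str:
        self.storage[checkpoint_id] = staged  # stands in for serialize + write
        return checkpoint_id

    def async_save(self, state_dict: dict, checkpoint_id: str) -> ToySaveResponse:
        staged = copy.deepcopy(state_dict)  # staging happens before returning
        staging_done: Future = Future()
        staging_done.set_result(None)
        upload = self._pool.submit(self._upload, staged, checkpoint_id)
        return ToySaveResponse(staging_done, upload)

    def close(self) -> None:
        self._pool.shutdown(wait=True)


ckpt = ToyAsyncCheckpointer()
model = {"w": [1.0, 2.0]}
resp = ckpt.async_save(model, "step_10")
model["w"][0] = 99.0                    # training continues; staged copy unaffected
print(resp.upload_completion.result())  # step_10
print(ckpt.storage["step_10"])          # {'w': [1.0, 2.0]}
ckpt.close()
```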

This method is deprecated. Please switch to ‘save’.

Load a checkpoint into a distributed state dict in SPMD style.

Each rank must have the same keys in the state_dict it provides to this API. Mismatched keys may result in hangs or errors. If unsure, you can use the utils._assert_same_keys API to check (but it may incur communication costs).
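
Conceptually, that check compares every rank's key set. Below is a single-process sketch that takes each rank's keys as a list; in a real job those lists would first have to be gathered across ranks, which is where the communication cost comes from. `find_key_mismatches` is a hypothetical helper, not the `_assert_same_keys` implementation.

```python
def find_key_mismatches(keys_per_rank: list[list[str]]) -> dict[int, set[str]]:
    """For each rank missing keys that some other rank has, return those keys."""
    union: set[str] = set().union(*keys_per_rank)
    missing = {}
    for rank, keys in enumerate(keys_per_rank):
        gap = union - set(keys)
        if gap:
            missing[rank] = gap
    return missing


ranks = [["model.w", "optim.step"], ["model.w", "optim.step"], ["model.w"]]
print(find_key_mismatches(ranks))  # {2: {'optim.step'}}
```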

Each rank will try to read the least amount of data necessary to fulfill the requested state_dict. When loading ShardedTensor or DTensor instances, each rank only reads data for its local shards.

For each Stateful object (having both a state_dict and a load_state_dict), load will first call state_dict before attempting deserialization, followed by load_state_dict once the deserialization is complete. For each non-Stateful object, load will deserialize the object, and then replace it in the state_dict with the deserialized object.

All tensors in state_dict must be allocated on their destination device prior to calling this function.

All non-tensor data is loaded using torch.load() and modified in place on state_dict.

Users must call load_state_dict on the root module to ensure load post-processing and non-tensor data properly propagate.
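
The Stateful vs. non-Stateful rule can be modeled in a few lines. This hypothetical `toy_load` (stdlib only) fills a Stateful object by calling `state_dict()` first, writing the checkpointed values into that dict, and then calling `load_state_dict()`; plain values are simply replaced. It ignores tensors, sharding, and storage entirely.

```python
from typing import Any


def _is_stateful(obj: Any) -> bool:
    return callable(getattr(obj, "state_dict", None)) and callable(
        getattr(obj, "load_state_dict", None)
    )


def toy_load(state_dict: dict[str, Any], checkpoint: dict[str, Any]) -> None:
    for key, value in list(state_dict.items()):
        if _is_stateful(value):
            inner = value.state_dict()         # 1. ask for the current state first
            inner.update(checkpoint[key])      # 2. "deserialize" into it
            value.load_state_dict(inner)       # 3. hand the result back
        else:
            state_dict[key] = checkpoint[key]  # non-Stateful: replace the entry


class Counter:
    def __init__(self):
        self.n = 0

    def state_dict(self):
        return {"n": self.n}

    def load_state_dict(self, sd):
        self.n = sd["n"]


sd = {"counter": Counter(), "epoch": 0}
toy_load(sd, {"counter": {"n": 7}, "epoch": 3})
print(sd["counter"].n, sd["epoch"])  # 7 3
```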

state_dict (Dict[str, Any]) – The state_dict to load the checkpoint into.

checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)

storage_reader (Optional[StorageReader]) – Instance of StorageReader used to perform reads. If this is not specified, DCP will automatically infer the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)

planner (Optional[LoadPlanner]) – Instance of LoadPlanner. If this is not specified, the default planner will be used. (Default: None)

process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)

no_dist (bool) – If True, this function will assume the intent is to load a checkpoint without using cross-rank synchronization. (Default: False)

load_state_dict uses collectives to coordinate reads across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

This method is deprecated. Please switch to ‘load’.

The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (torch.distributed.checkpoint.async_save):

This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users to customize how data is staged before the usual dcp.save path is executed in parallel. The expected order of operations (concretely defined in torch.distributed.checkpoint.state_dict_saver.async_save) is the following:

1. AsyncStager.stage(state_dict): This call gives the AsyncStager the opportunity to ‘stage’ the state_dict. The expectation and purpose of staging in this context is to create a “training-safe” representation of the state dict, meaning that any updates to module data after staging is complete should not be reflected in the state dict returned from this method. For example, in the default case a copy of the entire state dict is created on CPU RAM and returned here, allowing users to continue training without risking changes to data which is being serialized.

2. dcp.save is called in parallel on the state_dict returned from stage. This call is responsible for serializing the state_dict and writing it to storage.

3. If AsyncStager.should_synchronize_after_execute is True, AsyncStager.synchronize_staging is called immediately after the serialization thread starts and before returning from dcp.async_save. If this is set to False, the assumption is that the user has defined a custom synchronization point for the purpose of further optimizing save latency in the training loop (for example, by overlapping staging with the forward/backward pass), and it is the responsibility of the user to call AsyncStager.synchronize_staging at the appropriate time.

Clean up all resources used by the stager.

Whether to synchronize after executing the stage.

Returns a “staged” copy of state_dict. The expectation of the staged copy is that it is insulated from any updates incurred after the stage call is complete.

Union[Future[dict[str, Union[~StatefulT, Any]]], dict[str, Union[~StatefulT, Any]]]

If stage is asynchronous in some way, this method should be called to ensure staging is complete and that it is safe to begin modifying the original state_dict.
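
A minimal sketch of a stager that follows this protocol, using a background thread for the copy so that `synchronize_staging` has real work to do. It is a stdlib toy: `ToyThreadStager` is hypothetical, it does not subclass PyTorch's `AsyncStager`, and `copy.deepcopy` stands in for the GPU-to-CPU copy.

```python
import copy
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Optional


class ToyThreadStager:
    def __init__(self, synchronize_after_execute: bool = True):
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: Optional[Future] = None
        self._sync_after = synchronize_after_execute

    @property
    def should_synchronize_after_execute(self) -> bool:
        return self._sync_after

    def stage(self, state_dict: dict[str, Any]) -> Future:
        # Copy in the background; the returned future yields the staged dict.
        self._pending = self._pool.submit(copy.deepcopy, state_dict)
        return self._pending

    def synchronize_staging(self) -> None:
        # Block until the copy is finished; afterwards, mutating the
        # original state_dict can no longer affect the staged copy.
        if self._pending is not None:
            self._pending.result()

    def close(self) -> None:
        self._pool.shutdown(wait=True)


stager = ToyThreadStager()
weights = {"w": [0.5, 0.25]}
fut = stager.stage(weights)
if stager.should_synchronize_after_execute:
    stager.synchronize_staging()
weights["w"][0] = -1.0  # safe: staging already completed
print(fut.result())     # {'w': [0.5, 0.25]}
stager.close()
```

Setting `synchronize_after_execute=False` corresponds to the third step's opt-out: the caller becomes responsible for calling `synchronize_staging` before touching the original data.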

DefaultStager provides a full-featured staging implementation that combines multiple optimization techniques for efficient checkpoint preparation.

The staging process works as follows:

1. The state dictionary is submitted for staging (sync or async).
2. Tensors are copied from GPU to optimized CPU storage.
3. CUDA operations are synchronized if non-blocking copies are used.
4. The staged state dictionary is returned or made available via a Future.

```python
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions

# Synchronous staging
stager = DefaultStager(StagingOptions(use_async_staging=False))
staged_dict = stager.stage(state_dict)
stager.close()

# Asynchronous staging
stager = DefaultStager(StagingOptions(use_async_staging=True))
future = stager.stage(state_dict)
# ... do other work ...
staged_dict = future.result()
stager.close()

# Context manager pattern (recommended)
stager = DefaultStager(config)
with stager:
    result = stager.stage(state_dict)
```

- Async staging provides the best performance when model computation can overlap with staging operations.

- Pinned memory improves CPU-GPU transfer speeds but uses more memory.

- Shared memory allows efficient IPC to the checkpoint process.

- Non-blocking copies reduce GPU idle time during memory transfers.

DefaultStager is not thread-safe. Each thread should use its own instance, or external synchronization should be provided.

Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor used for async staging operations and cleans up the underlying StateDictStager’s cached storages. Should be called when the stager is no longer needed to prevent resource leaks, especially in long-running applications. After calling close(), the stager should not be used for further staging operations.

```python
stager = DefaultStager(StagingOptions(use_async_staging=True))
future = stager.stage(state_dict)
result = future.result()
stager.close()  # Clean up all resources
```

This function is responsible for staging the state_dict. See the class docstring for more details on staging. If use_async_staging is True, it will return a Future object that will be fulfilled when staging is complete. If use_async_staging is False, it will return the fully staged state_dict.

state_dict (STATE_DICT_TYPE) – The state_dict to be staged.

Union[dict[str, Union[~StatefulT, Any]], Future[dict[str, Union[~StatefulT, Any]]]]

When use_async_staging is True, this method will wait until staging is complete. If use_async_staging is False, this method is a no-op.

Configuration options for checkpoint staging behavior.

use_pinned_memory (bool) – Enable pinned memory allocation for faster CPU-GPU transfers. Requires CUDA to be available. Default: True

use_shared_memory (bool) – Enable shared memory for multi-process scenarios. Useful when multiple processes need access to the same staged data. Default: True

use_async_staging (bool) – Enable asynchronous staging using a background thread pool. Allows overlapping computation with staging operations. Requires CUDA. Default: True

use_non_blocking_copy (bool) – Use non-blocking device memory copies with stream synchronization. Improves performance by allowing CPU work to continue during GPU transfers. Default: True

CUDA-dependent features will raise an exception if CUDA is not available.

An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete. This implementation also provides an option to optimize stage latency using pinned memory.

N.B. synchronize_staging is a no-op in this case.

Returns a copy of state_dict on the CPU.

dict[str, Union[~StatefulT, Any]]

No-op function, since staging is blocking.

In addition to the above entrypoints, Stateful objects, as described below, provide additional customization during saving/loading.

Stateful protocol for objects that can be checkpointed and restored.

Restore the object’s state from the provided state_dict.

state_dict (dict[str, Any]) – The state dict to restore from

Objects should return their state_dict representation as a dictionary. The output of this function will be checkpointed, and later restored in load_state_dict().

Because of the in-place nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.

The object’s state dict
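
An object satisfies the Stateful protocol by providing just those two methods. The sketch below defines its own `runtime_checkable` Protocol with the same two method names instead of importing PyTorch's; `TrainProgress` is a hypothetical example of non-tensor training state (epoch and step counters) that one might checkpoint alongside a model.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class StatefulLike(Protocol):
    def state_dict(self) -> dict[str, Any]: ...
    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


class TrainProgress:
    def __init__(self) -> None:
        self.epoch = 0
        self.step = 0

    def state_dict(self) -> dict[str, Any]:
        # Everything returned here ends up in the checkpoint.
        return {"epoch": self.epoch, "step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Restore in place from a previously saved dict.
        self.epoch = state_dict["epoch"]
        self.step = state_dict["step"]


progress = TrainProgress()
progress.epoch, progress.step = 2, 500
saved = progress.state_dict()

fresh = TrainProgress()
fresh.load_state_dict(saved)
print(isinstance(fresh, StatefulLike), fresh.epoch, fresh.step)  # True 2 500
```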

This example shows how to use PyTorch Distributed Checkpoint to save an FSDP model.

The following types define the IO interface used during checkpoint:

Interface used by load_state_dict to read from storage.

One StorageReader instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.

A subclass should expect the following sequence of calls by load_state_dict:

1. (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
2. (all ranks) read_metadata()
3. (all ranks) set_up_storage_reader()
4. (all ranks) prepare_local_plan()
5. (coordinator) prepare_global_plan()
6. (all ranks) read_data()
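
That ordering can be checked with a small single-process simulation. Everything here is hypothetical: `RecordingReader` only logs which lifecycle method ran on which rank (with `reset` standing in for the "set checkpoint_id" step), plans are plain strings, and a loop over "ranks" replaces the collectives DCP uses to exchange plans.

```python
class RecordingReader:
    """Hypothetical reader that only records the lifecycle calls it receives."""

    def __init__(self, rank: int, log: list):
        self.rank = rank
        self.log = log

    def _record(self, name: str) -> None:
        self.log.append((self.rank, name))

    def reset(self, checkpoint_id=None) -> None:
        self._record("reset")

    def read_metadata(self) -> dict:
        self._record("read_metadata")
        return {"checkpoint": "toy"}

    def set_up_storage_reader(self, metadata: dict, is_coordinator: bool) -> None:
        self._record("set_up_storage_reader")

    def prepare_local_plan(self, plan: str) -> str:
        self._record("prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans: list) -> list:
        self._record("prepare_global_plan")
        return plans

    def read_data(self, plan: str) -> None:
        self._record("read_data")


def toy_load_sequence(num_ranks: int, checkpoint_id: str) -> list:
    log: list = []
    readers = [RecordingReader(r, log) for r in range(num_ranks)]
    coordinator = readers[0]  # in this toy, rank 0 plays the coordinator
    for rd in readers:
        rd.reset(checkpoint_id)
    metadata = [rd.read_metadata() for rd in readers]
    for rd, md in zip(readers, metadata):
        rd.set_up_storage_reader(md, is_coordinator=(rd is coordinator))
    local_plans = [rd.prepare_local_plan(f"plan-{rd.rank}") for rd in readers]
    global_plans = coordinator.prepare_global_plan(local_plans)  # coordinator only
    for rd, plan in zip(readers, global_plans):
        rd.read_data(plan)
    return log


print([name for rank, name in toy_load_sequence(2, "ckpt") if rank == 0])
# ['reset', 'read_metadata', 'set_up_storage_reader', 'prepare_local_plan',
#  'prepare_global_plan', 'read_data']
```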

Perform centralized planning of storage loading.

This method is only called on the coordinator instance.

While this method can produce a completely different plan, the preferred way is to store storage-specific data in LoadPlan::storage_data.

plans (list[torch.distributed.checkpoint.planner.LoadPlan]) – A list of LoadPlan instances, one for each rank.

A list of transformed LoadPlan after storage global planning

list[torch.distributed.checkpoint.planner.LoadPlan]

Perform storage-specific local planning.

While this method can produce a completely different plan, the recommended way is to store storage-specific data in LoadPlan::storage_data.

plan (LoadPlan) – The local plan from the LoadPlanner in use.

A transformed LoadPlan after storage local planning

Read all items from plan using planner to resolve the data.

A subclass should call LoadPlanner::load_bytes to deserialize a BytesIO object into the right place.

A subclass should call LoadPlanner::resolve_tensor to get access to the tensors that it should load data into.

It is the storage layer’s responsibility to properly schedule any cross-device copies required.

plan (LoadPlan) – The local plan to execute on

planner (LoadPlanner) – The planner object to use to resolve items.

A future that completes once all reads are finished.

Read the checkpoint metadata.

The metadata object associated with the checkpoint being loaded.

Called to indicate that a brand-new checkpoint read is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint read. The meaning of the checkpoint_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.

checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is more like a key-value store. (Default: None)

Initialize this instance.

metadata (Metadata) – The metadata schema to use.

is_coordinator (bool) – Whether this instance is responsible for coordinating the checkpoint.

Check if the given checkpoint_id is supported by the storage. This allows us to enable automatic storage selection.
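
Automatic storage selection (how DCP can infer a reader or writer from checkpoint_id alone) amounts to asking each candidate backend whether it accepts the ID. The sketch below is a toy registry: both backend classes and `pick_storage` are hypothetical, though the classmethod name `validate_checkpoint_id` matches the StorageReader/StorageWriter method described here.

```python
import os
from typing import Union


class ToyS3Reader:
    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return str(checkpoint_id).startswith("s3://")


class ToyFileSystemReader:
    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        # Accept anything that looks like a local path (no URI scheme).
        return "://" not in str(checkpoint_id)


def pick_storage(checkpoint_id, candidates=(ToyS3Reader, ToyFileSystemReader)):
    # First backend that claims the ID wins.
    for cls in candidates:
        if cls.validate_checkpoint_id(checkpoint_id):
            return cls
    raise RuntimeError(f"no storage backend accepts {checkpoint_id!r}")


print(pick_storage("s3://bucket/ckpt").__name__)     # ToyS3Reader
print(pick_storage("/tmp/ckpt/step_100").__name__)  # ToyFileSystemReader
```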

Interface used by save_state_dict to write to storage.

One StorageWriter instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.

A subclass should expect the following sequence of calls.

1. (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
2. (all ranks) set_up_storage_writer()
3. (all ranks) prepare_local_plan()
4. (coordinator) prepare_global_plan()
5. (all ranks) write_data()
6. (coordinator) finish()
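
As a heavily simplified illustration of that lifecycle, here is an in-memory writer that uses the same method names in the same order. It is hypothetical and stdlib-only: plans are lists of item keys, `write_data` takes the data dict directly (the real method takes a planner) and returns an already-completed `concurrent.futures.Future`, and `finish` records metadata plus all ranks' results. It is not a drop-in StorageWriter.

```python
from concurrent.futures import Future


class ToyMemoryWriter:
    def __init__(self, store: dict):
        self.store = store
        self.checkpoint_id = None
        self.is_coordinator = False

    def reset(self, checkpoint_id=None):
        self.checkpoint_id = checkpoint_id

    def set_up_storage_writer(self, is_coordinator: bool):
        self.is_coordinator = is_coordinator

    def prepare_local_plan(self, plan: list) -> list:
        return plan

    def prepare_global_plan(self, plans: list) -> list:
        # Coordinator only: here it just sorts each rank's item keys.
        return [sorted(p) for p in plans]

    def write_data(self, plan: list, data: dict) -> Future:
        results = []
        for key in plan:
            self.store[(self.checkpoint_id, key)] = data[key]
            results.append(key)
        fut: Future = Future()
        fut.set_result(results)
        return fut

    def finish(self, metadata: dict, results: list) -> None:
        # Coordinator only: writing metadata marks the checkpoint as complete.
        self.store[(self.checkpoint_id, ".metadata")] = {**metadata, "written": results}


store: dict = {}
rank_data = [{"b": 2, "a": 1}, {"c": 3}]
writers = [ToyMemoryWriter(store) for _ in rank_data]
for r, w in enumerate(writers):
    w.reset("step_5")
    w.set_up_storage_writer(is_coordinator=(r == 0))
local = [w.prepare_local_plan(list(d)) for w, d in zip(writers, rank_data)]
final = writers[0].prepare_global_plan(local)
results = [w.write_data(p, d).result() for w, p, d in zip(writers, final, rank_data)]
writers[0].finish({"version": 1}, results)
print(store[("step_5", ".metadata")])  # {'version': 1, 'written': [['a', 'b'], ['c']]}
```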

Write the metadata and mark the current checkpoint as successful.

The actual format/schema used for serializing metadata is an implementation detail. The only requirement is that it’s recoverable into the same object graph.

metadata (Metadata) – metadata for the new checkpoint

results (list[list[torch.distributed.checkpoint.storage.WriteResult]]) – A list of WriteResults from all ranks.

Perform centralized planning of storage.

This method is only called on the coordinator instance.

While this method can produce a completely different plan, the preferred way is to store storage-specific data in SavePlan::storage_data.

plans (list[torch.distributed.checkpoint.planner.SavePlan]) – A list of SavePlan instances, one for each rank.

A list of transformed SavePlan after storage global planning

list[torch.distributed.checkpoint.planner.SavePlan]

Perform storage-specific local planning.

While this method can produce a completely different plan, the recommended way is to store storage-specific data in SavePlan::storage_data.

plan (SavePlan) – The local plan from the SavePlanner in use.

A transformed SavePlan after storage local planning

Called to indicate that a brand-new checkpoint write is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint write. The meaning of the checkpoint_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.

checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)

Initialize this instance.

is_coordinator (bool) – Whether this instance is responsible for coordinating the checkpoint.

Return the storage-specific metadata. This is used to store additional information in a checkpoint that can be useful for providing request-level observability. StorageMeta is passed to the SavePlanner during save calls. Returns None by default.

Optional[StorageMeta]

Check if the given checkpoint_id is supported by the storage. This allows us to enable automatic storage selection.

Write all items from plan using planner to resolve the data.

A subclass should call SavePlanner::resolve_data on each item from the plan to get access to the underlying object to write.

Subclasses should lazily call resolve_data, as it can allocate memory. In the case of tensors, make the following assumptions:

- They might be on any device, including one that does not match WriteItem::tensor_data.

- They might be views or not contiguous. Only the projection needs to be saved.

plan (SavePlan) – The save plan to execute.

planner (SavePlanner) – Planner object to be used to resolve items to data.

A future that completes to a list of WriteResult

Future[list[torch.distributed.checkpoint.storage.WriteResult]]

The following types define the planner interface used during checkpoint:

Abstract class defining the protocol used by load_state_dict to plan the load process.

LoadPlanners are stateful objects that can be used to customize the whole load process.

LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.

A planner subclass can expect the following sequence of calls during load_state_dict:

1. set_up_planner (called on all ranks): Signals the start of loading a checkpoint.
2. create_local_plan (called on all ranks): Processes the state_dict and produces a LoadPlan that will be sent for global planning.
3. create_global_plan (called on the coordinator rank only): Takes the LoadPlan from all ranks and makes any global decisions.
4. load_bytes (called multiple times on each rank): Called once per non-tensor value in state_dict.
5. resolve_tensor and commit_tensor (called multiple times on each rank): Called in pairs for each Tensor value in state_dict.

Users are recommended to extend DefaultLoadPlanner instead of this interface directly, as most changes can be expressed by changes in a single method.

There are two usual patterns of extension:

- Rewriting state_dict. This is the simplest way to extend the load process, as it doesn’t require understanding the intricacies of how LoadPlan works. A reference to the original state_dict must be kept: because load happens in place, the loaded data has to be written back into it.

- Modifying resolve_tensor and commit_tensor to handle load-time transformation.

Called once the StorageReader has finished loading data into tensor.

The provided tensor is the same one returned by the call to resolve_tensor. This method is only needed if this LoadPlanner needs to post-process tensor prior to copying it back to the one in the state_dict.

The contents of tensor will follow its device synchronization model.

Compute the global load plan and return plans for each rank.

N.B. This is called on the coordinator rank only.

list[torch.distributed.checkpoint.planner.LoadPlan]

Create a LoadPlan based on state_dict and metadata provided by set_up_planner.

N.B. This is called on every rank.

Accept the plan from the coordinator and return the final LoadPlan.

Load the item described by read_item and value.

This method is expected to modify the underlying state_dict in place.

The contents of value are defined by the SavePlanner used to produce the checkpoint being loaded.

Return the BytesIO to be used by the StorageReader to load read_item.

The BytesIO should alias with one on the underlying state_dict, as StorageReader will replace its contents.

Return the tensor described by read_item to be used by the StorageReader to load read_item.

The tensor should alias with one on the underlying state_dict, as StorageReader will replace its contents. If, for any reason, that’s not possible, the planner can use the commit_tensor method to copy the data back to the one in state_dict.
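
The alias-or-copy-back contract can be sketched without tensors. In the hypothetical toy below, `bytearray` buffers play the role of tensors and a plain key stands in for a ReadItem: items that need no transformation resolve to the state_dict buffer itself (an alias, so the reader writes straight into it), while items that need load-time post-processing get a scratch buffer that `commit_tensor` transforms and copies back in place. `ToyLoadPlanner`, `toy_read_data`, and the upper-casing "transformation" are invented for illustration.

```python
class ToyLoadPlanner:
    def __init__(self, state_dict: dict, transform_keys: set):
        self.state_dict = state_dict          # destination buffers, pre-allocated
        self.transform_keys = transform_keys  # items needing post-processing

    def resolve_tensor(self, key: str) -> bytearray:
        dest = self.state_dict[key]
        if key not in self.transform_keys:
            return dest                       # alias: the reader fills dest directly
        return bytearray(len(dest))           # scratch buffer, copied back later

    def commit_tensor(self, key: str, buf: bytearray) -> None:
        dest = self.state_dict[key]
        if buf is dest:
            return                            # aliased: nothing to copy
        dest[:] = bytes(buf).upper()          # post-process, then copy in place


def toy_read_data(planner: ToyLoadPlanner, stored: dict) -> None:
    # Stands in for StorageReader.read_data: resolve, fill, commit, per item.
    for key, payload in stored.items():
        buf = planner.resolve_tensor(key)
        buf[:] = payload
        planner.commit_tensor(key, buf)


sd = {"a": bytearray(3), "b": bytearray(3)}
a_before = sd["a"]
toy_read_data(ToyLoadPlanner(sd, transform_keys={"b"}), {"a": b"xyz", "b": b"xyz"})
print(sd, sd["a"] is a_before)
# {'a': bytearray(b'xyz'), 'b': bytearray(b'XYZ')} True
```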

Initialize this instance to load data into state_dict.

N.B. This is called on every rank.

Abstract class defining the protocol used by save_state_dict to plan the save process.

SavePlanners are stateful objects that can be used to customize the whole save process.

SavePlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.

A planner subclass can expect the following sequence of calls during save_state_dict:

1. set_up_planner (called on all ranks): Signals the start of a checkpoint save.
2. create_local_plan (called on all ranks): Processes the state_dict and produces a SavePlan that will be sent for global planning.
3. create_global_plan (called on the coordinator rank only): Takes the SavePlan from all ranks and makes any global decisions.
4. finish_plan (called on all ranks): Gives each rank a chance to adjust to global planning decisions.
5. resolve_data (called multiple times on each rank): Looks up a value in the state_dict for the storage layer to write.

Users are recommended to extend DefaultSavePlanner instead of this interface directly, as most changes can be expressed by changes in a single method.

There are three usual patterns of extension:

- Rewriting state_dict. This is the simplest way to extend the save process, as it doesn’t require understanding the intricacies of how SavePlan works.

- Modifying the local plan and lookup in tandem. This is useful when fine control over how data is persisted is needed.

- Using the global planning step to make central decisions that can’t be made individually by each rank.

Finally, some planners need to save additional metadata in the checkpoint. This is accomplished by having each rank contribute its data items in the local plan and having the global planner aggregate them.
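
The call sequence and two of these patterns (rewriting the state_dict, and aggregating per-rank metadata in the global step) fit together in a stdlib toy. `ToySavePlanner` is hypothetical: plans are dicts rather than SavePlan objects, adding a `model.` key prefix stands in for a real state_dict rewrite, and pickling into `io.BytesIO` stands in for serialization. The method signatures are simplified relative to the real SavePlanner.

```python
import io
import pickle


class ToySavePlanner:
    def __init__(self, prefix: str = "model."):
        self.prefix = prefix

    def set_up_planner(self, state_dict: dict, is_coordinator: bool) -> None:
        # Pattern 1: rewrite the state_dict (here: rename every key).
        self.state_dict = {self.prefix + k: v for k, v in state_dict.items()}
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> dict:
        # Each rank lists its items and contributes a bit of planner data.
        return {"items": sorted(self.state_dict), "planner_data": len(self.state_dict)}

    def create_global_plan(self, plans: list) -> tuple:
        # Coordinator only: aggregate the per-rank contributions centrally.
        metadata = {"total_items": sum(p["planner_data"] for p in plans)}
        return plans, metadata

    def finish_plan(self, new_plan: dict) -> dict:
        self.plan = new_plan
        return new_plan

    def resolve_data(self, key: str) -> io.BytesIO:
        # Serialize lazily, only when the storage layer asks for the item;
        # repeated calls yield equal bytes (idempotent).
        return io.BytesIO(pickle.dumps(self.state_dict[key]))


rank_state = [{"w": 1}, {"w": 2, "b": 3}]
planners = [ToySavePlanner() for _ in rank_state]
for r, (p, sd) in enumerate(zip(planners, rank_state)):
    p.set_up_planner(sd, is_coordinator=(r == 0))
local_plans = [p.create_local_plan() for p in planners]
global_plans, metadata = planners[0].create_global_plan(local_plans)
for p, gp in zip(planners, global_plans):
    p.finish_plan(gp)
print(metadata, planners[1].plan["items"])  # {'total_items': 3} ['model.b', 'model.w']
print(pickle.loads(planners[1].resolve_data("model.b").getvalue()))  # 3
```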
|
|
2618
|
+
|
|
2619
|
+
create_global_plan: Computes the global checkpoint plan and returns the local plan of each rank.

This is called on the coordinator rank only.

Return type: tuple[list[torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata]

create_local_plan: Computes the save plan for the current rank.

This will be aggregated and passed to create_global_plan. Planner-specific data can be passed through SavePlan.planner_data.

This is called on all ranks.

finish_plan: Merges the plan created by create_local_plan with the result of create_global_plan.

This is called on all ranks.

resolve_data: Transforms and prepares a write_item from the state_dict for storage, ensuring idempotency and thread safety.

Looks up the object associated with write_item in the state_dict and applies any transformation (such as serialization) before the storage layer consumes it.

Called on each rank multiple times, at least once per WriteItem in the final SavePlan.

This method should be idempotent and thread-safe. StorageWriter implementations are free to call it as often as they need.

Any transformation that allocates memory should be done lazily, when this method is called, in order to reduce the peak memory required by checkpointing.

Returned tensors can be on any device or in any format, and they can be views. It's the storage layer's responsibility to figure out how to save them.

Return type: Union[Tensor, BytesIO]

set_up_planner: Initializes this planner to save state_dict.

Implementations should save those values, as they won't be provided later in the save process.

This is called on all ranks.
WriteItem: Dataclass that holds information about what needs to be written to storage.

tensor_storage_size(): Calculates the storage size of the underlying tensor, or None if this is not a tensor write.

Returns: Optional[int], the storage size in bytes of the underlying tensor, if any.

We provide a filesystem-based storage layer:

FileSystemReader.checkpoint_id: Returns the checkpoint_id that will be used to load the checkpoint.

FileSystemWriter: Basic implementation of StorageWriter using file IO.

This implementation makes the following assumptions and simplifications:

- The checkpoint path is an empty or non-existent directory.
- File creation is atomic.
- The checkpoint consists of one file per write request, plus either a global .metadata file with the serialized metadata (if rank coordination is enabled) or a rank-local __{rank}.metadata file with the serialized metadata (if rank coordination is NOT enabled).

stage(): Override of AsyncStager.stage.

Return type: dict[str, Union[~StatefulT, Any]]
We also provide other storage layers, including ones that interact with HuggingFace safetensors:

- torch.distributed.checkpoint.HuggingFaceStorageReader
- torch.distributed.checkpoint.HuggingFaceStorageWriter
- torch.distributed.checkpoint.QuantizedHuggingFaceStorageReader

We provide default implementations of LoadPlanner and SavePlanner that can handle all torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor.

DefaultSavePlanner.lookup_object: Extension from the planner interface to make it easy to extend the default planner.

DefaultSavePlanner.transform_object: Extension from the planner interface to make it easy to extend the default planner.

DefaultLoadPlanner: A LoadPlanner that adds multiple features on top of the base LoadPlanner interface.

In particular, it adds the following:

- flatten_state_dict: Handles a state_dict with nested dicts.
- flatten_sharded_tensors: For FSDP in 2D parallel mode.
- allow_partial_load: If False, a runtime error is raised if a key is present in state_dict but not in the checkpoint.

DefaultLoadPlanner.lookup_tensor: Extension from the planner interface to make it easy to extend the default planner.

DefaultLoadPlanner.transform_tensor: Extension from the planner interface to make it easy to extend the default planner.
Due to legacy design decisions, the state dictionaries of FSDP and DDP may have different keys or fully qualified names (e.g., layer1.weight) even when the original unparallelized model is identical. Moreover, FSDP offers various types of model state dictionaries, such as full and sharded state dictionaries. Additionally, optimizer state dictionaries use parameter IDs instead of fully qualified names to identify parameters, which can cause issues when parallelisms are used (e.g., pipeline parallelism).

To tackle these challenges, we offer a collection of APIs for users to easily manage state_dicts. get_model_state_dict() returns a model state dictionary with keys consistent with those returned by the unparallelized model's state dictionary. Similarly, get_optimizer_state_dict() provides the optimizer state dictionary with keys that are uniform across all applied parallelisms. To achieve this consistency, get_optimizer_state_dict() converts parameter IDs to fully qualified names identical to those found in the unparallelized model state dictionary.

Note that the results returned by these APIs can be used directly with the torch.distributed.checkpoint.save() and torch.distributed.checkpoint.load() methods without any additional conversion.

set_model_state_dict() and set_optimizer_state_dict() are provided to load the model and optimizer state_dicts generated by their respective getter APIs.

Note that set_optimizer_state_dict() can only be called before backward() or after step() is called on the optimizers.

Note that this feature is experimental, and API signatures might change in the future.

get_state_dict: Returns the model state_dict and the optimizer state_dict.

get_state_dict can process any module that is parallelized by PyTorch FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, or any combination of these parallelisms. The main functions of get_state_dict are:

1. returning a model and optimizer state_dict that can be resharded with a different number of trainers and/or different parallelisms;
2. hiding the parallelism-specific state_dict APIs, so users don't have to call them;
3. sanity-checking the resulting state_dict.

The keys of the resulting state dictionary are the canonical FQNs (Fully Qualified Names). A canonical FQN refers to the FQN based on a parameter's position in an nn.Module hierarchy. More specifically, a parameter's canonical FQN is the FQN returned by module.named_parameters() or module.named_buffers() when the module is not distributed by any parallelism. Since the optimizer internally uses parameter IDs to represent a parameter, this API converts the parameter IDs to the canonical FQNs.

get_state_dict can also process a module that is not parallelized. In that case, get_state_dict performs only one function: converting the optimizer parameter IDs to the canonical FQNs.

Parameters:

- model (nn.Module) – the model.
- optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – the optimizers that are used to optimize model.
- submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules.
- options (StateDictOptions) – the options that control how the model state_dict and optimizer state_dict are returned. See StateDictOptions for details.

Returns: a tuple containing the model state_dict and the optimizer state_dict.

Return type: Tuple[Dict[str, ValueType], OptimizerStateType]
get_model_state_dict: Returns the model state_dict of model.

See get_state_dict for detailed usage.

Parameters:

- model (nn.Module) – the model.
- submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules.
- options (StateDictOptions) – the options that control how the model state_dict and optimizer state_dict are returned. See StateDictOptions for details.

Returns: the state_dict for model.

get_optimizer_state_dict: Returns the combined state_dict for optimizers.

See get_state_dict for detailed usage.

Parameters:

- model (nn.Module) – the model.
- optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – the optimizers that are used to optimize model.
- submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules.
- options (StateDictOptions) – the options that control how the model state_dict and optimizer state_dict are returned. See StateDictOptions for details.

Returns: the state_dict for optimizers.

set_state_dict: Loads the model state_dict and the optimizer state_dict.

The counterpart of get_state_dict, which sets the state_dict on the model and optimizers. The given model_state_dict and optim_state_dict do not have to be returned by get_state_dict but must meet the following requirements:

1. all FQNs are canonical FQNs as defined in get_state_dict;
2. if a tensor is sharded, it must be either a ShardedTensor or a DTensor;
3. the optimizer state_dict cannot contain parameter IDs; its keys must be the canonical FQNs.

Warning: set_state_dict can only be called before backward() or after step() is called on the optimizers. Otherwise, the optimizer states won't be initialized correctly.

Parameters:

- model (nn.Module) – the model.
- optimizers (Union[Optimizer, Iterable[Optimizer]]) – the optimizers that are used to optimize model.
- model_state_dict (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]) – the model state_dict to load. If a key of model_state_dict is an nn.Module, that key is a submodule of model and the value should be the state_dict of the submodule. When loading the state_dict, the prefix of the submodule is prepended to the state_dict keys.
- optim_state_dict (OptimizerStateType) – the optimizer state_dict to load.
- options (StateDictOptions) – the options that control how the model state_dict and optimizer state_dict are loaded. See StateDictOptions for details.

Returns: a NamedTuple with missing_keys and unexpected_keys fields:

- missing_keys is a list of str containing the missing keys of the model state_dict.
- unexpected_keys is a list of str containing the unexpected keys of the model state_dict.
set_model_state_dict: Loads the model state_dict.

The counterpart of get_model_state_dict, which sets the state_dict on the model. See set_state_dict for detailed usage.

Parameters:

- model (nn.Module) – the model.
- model_state_dict (Dict[str, ValueType]) – the model state_dict to load. If a key of model_state_dict is an nn.Module, that key is a submodule of model and the value should be the state_dict of the submodule. When loading the state_dict, the prefix of the submodule is prepended to the state_dict keys.
- options (StateDictOptions) – the options that control how the model state_dict and optimizer state_dict are loaded. See StateDictOptions for details.

Returns: a NamedTuple with missing_keys and unexpected_keys fields:

- missing_keys is a list of str containing the missing keys.
- unexpected_keys is a list of str containing the unexpected keys.

set_optimizer_state_dict: Loads the optimizer state_dict.

The counterpart of get_optimizer_state_dict, which sets the state_dict on the optimizers. See set_state_dict for detailed usage.

Warning: set_optimizer_state_dict can only be called before backward() or after step() is called on the optimizers. Otherwise, the optimizer states won't be initialized correctly.

Parameters:

- model (nn.Module) – the model.
- optimizers (Union[Optimizer, Iterable[Optimizer]]) – the optimizers that are used to optimize model.
- optim_state_dict (OptimizerStateType) – the optimizer state_dict to load.
- options (StateDictOptions) – the options that control how the model state_dict and optimizer state_dict are loaded. See StateDictOptions for details.

StateDictOptions: This dataclass specifies how get_state_dict/set_state_dict will work.

- full_state_dict: if this is set to True, all the tensors in the returned state_dict will be gathered. No ShardedTensor or DTensor will be in the returned state_dict.
- cpu_offload: offload all the tensors to CPU. To prevent CPU OOM, if full_state_dict is also True, then only rank 0 will get the state_dict and all other ranks will get an empty state_dict.
- ignore_frozen_params: if the value is True, the returned state_dict won't contain any frozen parameters (those whose requires_grad is False). The default value is False.
- keep_submodule_prefixes (deprecated): when submodules is not None, this option indicates whether to keep the submodule prefixes in the state_dict keys. For example, suppose the submodule is module.pretrain and the full FQN of a parameter is pretrain.layer1.weight. When this option is True, the parameter's key in the returned state_dict will be pretrain.layer1.weight; when it is False, the key will be layer1.weight. Note that if keep_submodule_prefixes is False, FQNs may conflict, so there should be only one submodule in submodules.
- strict: the strict option used when set_state_dict calls model.load_state_dict().
- broadcast_from_rank0: when this option is True, rank 0 should receive a full state_dict and will broadcast the tensors in the state_dict/optim_state_dict one by one to the other ranks. The other ranks receive the tensors and shard them according to the local shards in the model and optimizer. full_state_dict must be set to True when using this option. This option currently supports only DTensor, not the legacy ShardedTensor.
For users who are used to using and sharing models in the torch.save format, the following methods provide offline utilities for converting between formats.

dcp_to_torch_save: Given a directory containing a DCP checkpoint, converts it into a torch.save file.

Parameters:

- dcp_checkpoint_dir (Union[str, PathLike]) – directory containing the DCP checkpoint.
- torch_save_path (Union[str, PathLike]) – filename to store the converted torch.save file.

To avoid OOM, it's recommended to run this function on a single rank only.

torch_save_to_dcp: Given the location of a torch.save file, converts it into a DCP checkpoint.

Parameters:

- torch_save_path (Union[str, PathLike]) – filename of the torch.save file.
- dcp_checkpoint_dir (Union[str, PathLike]) – directory to store the DCP checkpoint.

To avoid OOM, it's recommended to run this function on a single rank only.
The following classes can also be used for online loading and resharding of models from the torch.save format.

BroadcastingTorchSaveReader: StorageReader for reading a torch.save file. This reader reads the entire checkpoint on the coordinator rank, and then broadcasts and shards each tensor to all ranks.

N.B. Intended to be used with DynamicMetaLoadPlanner.

The current implementation only supports loading Tensors.

Implementation of the StorageReader method.

Return type: list[torch.distributed.checkpoint.planner.LoadPlan]

Implementation of the StorageReader method.

read_data: Reads torch.save data on the coordinator rank and broadcasts it afterwards. This incurs a communication cost, but avoids having to load the entire checkpoint on each rank, which should help prevent OOM issues.

read_metadata: Extends the default StorageReader to support building the metadata file.

Implementation of the StorageReader method.

Implementation of the StorageReader method.

Implementation of the StorageReader method.

DynamicMetaLoadPlanner: Extension of DefaultLoadPlanner that creates a new Metadata object based on the passed-in state dict, avoiding the need to read metadata from disk. This is useful when reading formats that don't have a metadata file, like torch.save files.

N.B. Intended to be used with BroadcastingTorchSaveReader.

The current implementation only supports loading Tensors.

set_up_planner: Sets up the planner, extending the default behavior by creating the Metadata object from the state dict.

The following experimental interfaces are provided for improved observability in production environments:

---
## torch.distributed.tensor

**URL:** https://pytorch.org/docs/stable/distributed.tensor.html

**Contents:**
- torch.distributed.tensor
- PyTorch DTensor (Distributed Tensor)
- DTensor Class APIs
- DeviceMesh as the distributed communicator
- DTensor Placement Types
- Different ways to create a DTensor
- Create DTensor from a logical torch.Tensor
- DTensor Factory Functions
- Random Operations
- Debugging

Created On: Jun 13, 2025 | Last Updated On: Aug 23, 2025
torch.distributed.tensor is currently in alpha state and under development. We are committing to backward compatibility for most of the APIs listed in this doc, but there might be API changes if necessary.

PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed logic, including sharded storage, operator computation and collective communications across devices/hosts. DTensor can be used to build different parallelism solutions and supports a sharded state_dict representation when working with multi-dimensional sharding.

Please see examples from the PyTorch native parallelism solutions that are built on top of DTensor:

DTensor follows the SPMD (single program, multiple data) programming model, which lets users write a distributed program as if it were a single-device program with the same convergence properties. It provides a uniform tensor sharding layout (DTensor Layout) by specifying the DeviceMesh and Placement:

- DeviceMesh represents the device topology and the communicators of the cluster using an n-dimensional array.
- Placement describes the sharding layout of the logical tensor on the DeviceMesh. DTensor supports three types of placements: Shard, Replicate and Partial.

DTensor is a torch.Tensor subclass. This means that once a DTensor is created, it can be used in a very similar way to torch.Tensor, including running different types of PyTorch operators as if running them on a single device, which allows proper distributed computation for PyTorch operators.

In addition to the existing torch.Tensor methods, DTensor offers a set of additional methods to interact with torch.Tensor, redistribute the DTensor Layout to a new DTensor, get the full tensor content on all devices, etc.

DTensor (Distributed Tensor) is a subclass of torch.Tensor that provides a single-device-like abstraction for programming with multi-device torch.Tensor. It describes the distributed tensor sharding layout (DTensor Layout) through the DeviceMesh and the following types of Placement:

- Shard: Tensor sharded on the tensor dimension dim on the devices of the DeviceMesh dimension
- Replicate: Tensor replicated on the devices of the DeviceMesh dimension
- Partial: Tensor is pending reduction on the devices of the DeviceMesh dimension

When PyTorch operators are called, DTensor overrides them to perform sharded computation and issue communications whenever necessary. Along with the operator computation, DTensor transforms or propagates the placements (DTensor Layout) properly (based on the operator semantics) and generates new DTensor outputs.

To ensure numerical correctness of the DTensor sharded computation when calling PyTorch operators, DTensor requires every Tensor argument of the operator to be a DTensor.

Directly using the Tensor subclass constructor here is not the recommended way to create a DTensor (i.e. it does not handle autograd correctly and hence is not the public API). Please refer to the create_dtensor section to see how to create a DTensor.
__create_chunk_list__: Returns a list of ChunkStorageMetadata, a dataclass that describes the size/offset of the local shard/replica on the current rank. For DTensor, each rank has a single local shard/replica, so the returned list usually has only one element.

This dunder method is primarily used for distributed checkpointing.

Returns: a List[ChunkStorageMetadata] object that represents the shard size/offset on the current rank.

from_local: Creates a DTensor from a local torch.Tensor on each rank according to the specified device_mesh and placements.

Parameters:

- local_tensor (torch.Tensor) – local torch.Tensor on each rank.
- device_mesh (DeviceMesh, optional) – DeviceMesh to place the tensor on. If not specified, this must be called under a DeviceMesh context manager. Default: None
- placements (List[Placement], optional) – the placements that describe how to place the local torch.Tensor on the DeviceMesh; must have the same number of elements as device_mesh.ndim.
- run_check (bool, optional) – at the cost of extra communication, performs a sanity check across ranks on each local tensor's meta information to ensure correctness. If placements contains Replicate, the data on the first rank of the device mesh dimension is broadcast to the other ranks. Default: False
- shape (torch.Size, optional) – a list of ints that specifies the size of the DTensor built on top of local_tensor. This must be provided if the shape of local_tensor differs across ranks. If not provided, shape is computed assuming the given distributed tensor is evenly sharded across ranks. Default: None
- stride (tuple, optional) – a list of ints that specifies the stride of the DTensor. If not provided, stride is computed assuming the given distributed tensor is evenly sharded across ranks. Default: None

When run_check=False, it is the user's responsibility to ensure the local tensor passed in is correct across ranks (i.e. the tensor is sharded for the Shard(dim) placement or replicated for the Replicate() placement). If not, the behavior of the created DTensor is undefined.

from_local is differentiable; the requires_grad of the created DTensor object depends on whether local_tensor requires grad.
full_tensor: Returns the full tensor of this DTensor. It performs the necessary collectives to gather the local tensors from the other ranks in its DeviceMesh and concatenate them together. It's syntactic sugar for the following code:

dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()

Parameters:

- grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the full Tensor returned from this function. full_tensor converts a DTensor to a full torch.Tensor, and the returned torch.Tensor might not be used with the original replicated DTensor layout later in the code. This argument is a hint that the user can give to autograd in case the gradient layout of the returned tensor does not match the original replicated DTensor layout. If not specified, the gradient layout of the full tensor is assumed to be replicated.

Returns: a torch.Tensor object that represents the full tensor of this DTensor.

full_tensor is differentiable.

redistribute: Performs the necessary collective operations to redistribute the current DTensor from its current placements to new placements, or from its current DeviceMesh to a new DeviceMesh. For example, a sharded DTensor can be turned into a replicated DTensor by specifying a Replicate placement for each dimension of the DeviceMesh.

When redistributing from the current to the new placements on one device mesh dimension, the following operations (a communication collective or a local operation) are performed:

- Shard(dim) -> Replicate(): all_gather
- Shard(src_dim) -> Shard(dst_dim): all_to_all
- Replicate() -> Shard(dim): local chunking (i.e. torch.chunk)
- Partial() -> Replicate(): all_reduce
- Partial() -> Shard(dim): reduce_scatter

redistribute correctly figures out the necessary redistribution steps for DTensors created on either a 1-D or an N-D DeviceMesh.

Parameters:

- device_mesh (DeviceMesh, optional) – DeviceMesh to place the DTensor on. If not specified, the current DTensor's DeviceMesh is used. Default: None
- placements (List[Placement], optional) – the new placements that describe how to place the DTensor into the DeviceMesh; must have the same number of elements as device_mesh.ndim. Default: replicate on all mesh dimensions
- async_op (bool, optional) – whether to perform the DTensor redistribute operation asynchronously. Default: False
- forward_dtype (torch.dtype, optional) – the local tensor can be converted to forward_dtype before it is redistributed in the forward pass. The resulting DTensor will be in forward_dtype. Default: None
- backward_dtype (torch.dtype, optional) – the local tensor can be converted to backward_dtype before it is redistributed in the backward pass. The resulting DTensor gradient is converted back to the current DTensor dtype. Default: None

redistribute is differentiable, which means users do not need to worry about the backward formula of the redistribute operation.

redistribute currently only supports redistributing a DTensor on the same DeviceMesh. Please file an issue if you need to redistribute a DTensor to a different DeviceMesh.
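The collective behind each placement transition can be modeled without PyTorch. The sketch below is a toy simulation, not DTensor's implementation: each rank's local data on a 1-D mesh is a Python list, Replicate() -> Shard(0) is local chunking with torch.chunk-style sizes, and Shard(0) -> Replicate() is an all_gather followed by concatenation. All function names are hypothetical.

```python
from typing import List

# One inner list per rank on a 1-D mesh (a stand-in for each rank's local tensor).
LocalData = List[List[float]]


def _chunk(values: List[float], num_chunks: int) -> List[List[float]]:
    # torch.chunk-style split with chunk size ceil(n / k); slices past the end
    # are empty lists, which models empty trailing shards.
    size = -(-len(values) // num_chunks)
    return [values[i * size:(i + 1) * size] for i in range(num_chunks)]


def replicate_to_shard(replicas: LocalData) -> LocalData:
    """Replicate() -> Shard(0): each rank keeps its own chunk; no communication."""
    num_ranks = len(replicas)
    return [_chunk(local, num_ranks)[rank] for rank, local in enumerate(replicas)]


def shard_to_replicate(shards: LocalData) -> LocalData:
    """Shard(0) -> Replicate(): all_gather every shard, then concatenate."""
    gathered = [x for shard in shards for x in shard]
    return [list(gathered) for _ in shards]


if __name__ == "__main__":
    replicas = [[1.0, 2.0, 3.0, 4.0, 5.0] for _ in range(3)]
    shards = replicate_to_shard(replicas)
    print(shards)
    print(shard_to_replicate(shards))
```

The asymmetry is the point of the list above: going to Shard only drops data locally, while going back to Replicate has to move every shard to every rank.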
to_local: Gets the local tensor of this DTensor on its current rank. For sharding, it returns a local shard of the logical tensor view; for replication, it returns the replica on its current rank.

Parameters:

- grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the Tensor returned from this function. to_local converts a DTensor to a local tensor, and the returned local tensor might not be used with the original DTensor layout later in the code. This argument is a hint that the user can give to autograd in case the gradient layout of the returned tensor does not match the original DTensor layout. If not specified, the gradient layout is assumed to remain the same as the original DTensor's, and that layout is used for gradient computation.

Returns: a torch.Tensor or AsyncCollectiveTensor object representing the local tensor on its current rank. When an AsyncCollectiveTensor object is returned, the local tensor is not ready yet (i.e. communication is not finished). In this case, the user needs to call wait() to wait for the local tensor to be ready.

to_local is differentiable; the requires_grad of the returned local tensor depends on whether the DTensor requires grad.

device_mesh: The DeviceMesh attribute associated with this DTensor object.

device_mesh is a read-only property; it cannot be set.

placements: The placements attribute of this DTensor, which describes the layout of this DTensor on its DeviceMesh.

placements is a read-only property; it cannot be set.

DeviceMesh originated in DTensor as the abstraction that describes the cluster's device topology and represents multi-dimensional communicators (on top of ProcessGroup). For details on how to create and use a DeviceMesh, please refer to the DeviceMesh recipe.

DTensor supports the following types of Placement on each DeviceMesh dimension:

Shard(dim): The Shard(dim) placement describes DTensor sharding on tensor dimension dim over a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension holds only a shard/piece of the global Tensor. The Shard(dim) placement follows the torch.chunk(dim) semantics, where the last few shards on the DeviceMesh dimension might be empty when the tensor dimension is not evenly divisible on the DeviceMesh dimension. The Shard placement can be used by all DTensor APIs (e.g. distribute_tensor, from_local, etc.)

Parameters:

- dim (int) – the tensor dimension along which the DTensor is sharded over its corresponding DeviceMesh dimension.

Sharding on a tensor dimension whose size is not evenly divisible on a DeviceMesh dimension is currently experimental and subject to change.
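Because Shard(dim) follows torch.chunk, the per-rank sizes for an uneven dimension can be worked out by hand: torch.chunk uses a chunk size of ceil(n / k), so trailing ranks may get a smaller or empty shard. The pure-Python sketch below (hypothetical shard_layout helper) returns a (size, offset) pair per rank, loosely mirroring the size/offset information carried by ChunkStorageMetadata; clamping the offset of an empty shard to n is this sketch's own choice, not necessarily what DTensor records.

```python
def shard_layout(dim_size: int, num_ranks: int) -> list:
    """(local_size, global_offset) per rank, following torch.chunk sizing.

    torch.chunk uses a chunk size of ceil(dim_size / num_ranks), so trailing
    ranks can end up with a smaller or even empty shard.
    """
    if num_ranks <= 0:
        raise ValueError("num_ranks must be positive")
    chunk = -(-dim_size // num_ranks)  # ceiling division
    layout = []
    for rank in range(num_ranks):
        # Clamp the offset so empty shards point at the end of the dimension.
        offset = min(rank * chunk, dim_size)
        size = min(chunk, dim_size - offset)
        layout.append((size, offset))
    return layout


if __name__ == "__main__":
    print(shard_layout(8, 4))  # evenly divisible
    print(shard_layout(5, 4))  # last rank gets an empty shard
    print(shard_layout(4, 3))
```

For example, with 5 rows on 4 ranks the chunk size is 2, so the shards hold 2, 2, 1 and 0 rows.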
The Replicate() placement describes a DTensor replicated on the corresponding DeviceMesh dimension: each rank on that DeviceMesh dimension holds a replica of the global Tensor. The Replicate placement can be used by all DTensor APIs (e.g. distribute_tensor, DTensor.from_local, etc.).

The Partial(reduce_op) placement describes a DTensor that is pending reduction on a specified DeviceMesh dimension: each rank on that DeviceMesh dimension holds a partial value of the global Tensor. Users can redistribute a Partial DTensor to a Replicate or Shard(dim) placement on that DeviceMesh dimension using redistribute, which triggers the necessary communication operations under the hood (e.g. all-reduce, reduce-scatter).

reduce_op (str, optional) – The reduction op used to turn the partial DTensor into a replicated or sharded DTensor. Only element-wise reduction operations are supported: “sum”, “avg”, “product”, “max”, “min”. Default: “sum”.

The Partial placement can be produced as the result of DTensor operators, and can only be used by the DTensor.from_local API.
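The following pure-Python sketch models what redistributing a Partial placement computes, using plain lists in place of per-rank tensors. partial_to_replicate corresponds to an all-reduce, and partial_to_shard0 corresponds to a reduce-scatter onto Shard(0). Both helpers are hypothetical, and no real communication happens.

```python
import math
from functools import reduce

# Element-wise reductions allowed for Partial(reduce_op); "avg" is handled separately.
_OPS = {
    "sum": lambda a, b: a + b,
    "product": lambda a, b: a * b,
    "max": max,
    "min": min,
}


def reduce_partials(partials: list[list[float]], reduce_op: str = "sum") -> list[float]:
    """Combine each rank's partial values element-wise (what an all-reduce computes)."""
    if reduce_op == "avg":
        summed = reduce_partials(partials, "sum")
        return [v / len(partials) for v in summed]
    op = _OPS[reduce_op]
    # zip(*partials) groups the i-th element held by every rank.
    return [reduce(op, column) for column in zip(*partials)]


def partial_to_replicate(partials, reduce_op="sum"):
    """Partial -> Replicate: every rank ends up with the full reduced values."""
    full = reduce_partials(partials, reduce_op)
    return [list(full) for _ in partials]


def partial_to_shard0(partials, reduce_op="sum"):
    """Partial -> Shard(0): each rank keeps only its torch.chunk-style slice of the
    reduced values (what a reduce-scatter computes)."""
    full = reduce_partials(partials, reduce_op)
    chunk = math.ceil(len(full) / len(partials))
    return [full[r * chunk:(r + 1) * chunk] for r in range(len(partials))]


partials = [[1, 2, 3, 4], [10, 20, 30, 40]]  # two ranks, each holding a partial sum
print(partial_to_replicate(partials))  # [[11, 22, 33, 44], [11, 22, 33, 44]]
print(partial_to_shard0(partials))     # [[11, 22], [33, 44]]
```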
Placement is the base class that describes how a DTensor is placed onto the DeviceMesh; Placement and DeviceMesh together describe the DTensor layout. It is the base class of the three main DTensor Placement types: Shard, Replicate, and Partial.

This class is not meant to be used directly; it mainly serves as a typing stub.

distribute_tensor() creates a DTensor from a logical or “global” torch.Tensor on each rank. It can be used to shard leaf torch.Tensors (e.g. model parameters/buffers and inputs).

DTensor.from_local() creates a DTensor from a local torch.Tensor on each rank. It can be used to create DTensors from non-leaf torch.Tensors (e.g. intermediate activation tensors during forward/backward).

DTensor provides dedicated tensor factory functions (e.g. empty(), ones(), randn(), etc.) that create DTensors by directly specifying the DeviceMesh and Placement. Compared to distribute_tensor(), these functions can materialize the sharded memory on device directly, instead of sharding after the full logical Tensor has been allocated and initialized.
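The sketch below illustrates what materializing the sharded memory directly means. It computes the local shape each rank would allocate for a factory call on a multi-dimensional mesh. Placements are modeled as tuples, one per mesh dimension, and are applied in mesh-dimension order with torch.chunk-style sizes. local_shape is a hypothetical helper, not the DTensor implementation.

```python
import math

# Placements modeled as tuples: ("shard", tensor_dim) or ("replicate",),
# standing in for Shard(dim) / Replicate() for illustration only.


def local_shape(global_shape, mesh_shape, placements, coord):
    """Local shard shape for the rank at mesh coordinate `coord` (hypothetical helper)."""
    if not (len(mesh_shape) == len(placements) == len(coord)):
        raise ValueError("need one placement and one coordinate per mesh dimension")
    shape = list(global_shape)
    for mesh_size, placement, idx in zip(mesh_shape, placements, coord):
        if placement[0] == "shard":
            dim = placement[1]
            chunk = math.ceil(shape[dim] / mesh_size)
            # Later ranks may get a smaller (or empty) chunk when the size is uneven.
            shape[dim] = max(0, min(chunk, shape[dim] - idx * chunk))
        # ("replicate",) leaves the shape unchanged on this mesh dimension.
    return tuple(shape)


# An (8, 6) tensor on a 2 x 3 mesh, sharded on both tensor dims:
print(local_shape((8, 6), (2, 3), [("shard", 0), ("shard", 1)], (1, 2)))    # (4, 2)
# Replicating along the second mesh dimension keeps all 6 columns locally:
print(local_shape((8, 6), (2, 3), [("shard", 0), ("replicate",)], (0, 1)))  # (4, 6)
```

With this layout, a rank only needs a (4, 2) buffer for the first call, rather than the full (8, 6) tensor that distribute_tensor() would start from.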
The SPMD (single program, multiple data) programming model in torch.distributed launches multiple processes (e.g. via torchrun) that execute the same program. This means the model inside the program is first initialized on each process independently (e.g. on CPU, on the meta device, or directly on GPU if there is enough memory).

DTensor offers a distribute_tensor() API that can shard model weights or Tensors into DTensors, creating a DTensor from the “logical” Tensor on each process. This lets the created DTensors comply with single-device semantics, which is critical for numerical correctness.

Distribute a leaf torch.Tensor (i.e. nn.Parameter/buffers) to the device_mesh according to the specified placements. The number of dimensions of device_mesh must equal the number of placements. The tensor to distribute is the logical or “global” tensor. The API uses the tensor from the first rank of each DeviceMesh dimension as the source of truth to preserve single-device semantics. If you want to construct a DTensor in the middle of the autograd computation, use DTensor.from_local() instead.

tensor (torch.Tensor) – torch.Tensor to be distributed. If you shard a tensor on a dimension that is not evenly divisible by the number of devices in that mesh dimension, torch.chunk semantics are used to shard the tensor and scatter the shards. The uneven sharding behavior is experimental and subject to change.

device_mesh (DeviceMesh, optional) – DeviceMesh to distribute the tensor over. If not specified, the call must be made under a DeviceMesh context manager. Default: None

placements (List[Placement], optional) – the placements that describe how to place the tensor on the DeviceMesh; must have the same number of elements as device_mesh.ndim. If not specified, the tensor is by default replicated across the device_mesh from the first rank of each dimension of the device_mesh.

src_data_rank (int, optional) – the rank of the source data for the logical/global tensor. distribute_tensor() uses it to scatter/broadcast the shards/replicas to other ranks. By default, group_rank=0 on each DeviceMesh dimension is used as the source data, which preserves single-device semantics. If None is passed explicitly, distribute_tensor() simply uses its local data instead of trying to preserve single-device semantics via scatter/broadcast. Default: 0

Returns: A DTensor or XLAShardedTensor object.

When the DeviceMesh is initialized with the xla device_type, distribute_tensor returns an XLAShardedTensor instead. See this issue for more details. The XLA integration is experimental and subject to change.
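The role of src_data_rank can be modeled in a few lines of pure Python. The sketch assumes a 1-D mesh, lists in place of tensors, and Shard(0) as the only sharding; distribute_1d is hypothetical. Suppose the ranks initialized their “global” copies differently. Broadcasting or scattering from one source rank still gives every rank a consistent view, while src_data_rank=None keeps each rank's local data.

```python
import math


def distribute_1d(per_rank_global, placement, src_data_rank=0):
    """Model distribute_tensor() on a 1-D mesh (hypothetical helper).

    per_rank_global[r] is the "global" list that rank r built locally, e.g. after a
    random init that differed across ranks. placement is "replicate" or "shard".
    """
    world = len(per_rank_global)
    local_views = []
    for rank in range(world):
        # src_data_rank=None: each rank uses its own data (no scatter/broadcast).
        if src_data_rank is None:
            source = per_rank_global[rank]
        else:
            source = per_rank_global[src_data_rank]
        if placement == "replicate":
            local_views.append(list(source))  # broadcast of the source copy
        elif placement == "shard":
            chunk = math.ceil(len(source) / world)
            local_views.append(source[rank * chunk:(rank + 1) * chunk])  # scatter
        else:
            raise ValueError(f"unknown placement: {placement}")
    return local_views


inits = [[1, 2, 3, 4], [9, 9, 9, 9]]  # two ranks disagree about the "same" tensor
print(distribute_1d(inits, "replicate"))                  # [[1, 2, 3, 4], [1, 2, 3, 4]]
print(distribute_1d(inits, "shard"))                      # [[1, 2], [3, 4]]
print(distribute_1d(inits, "shard", src_data_rank=None))  # [[1, 2], [9, 9]]
```

The last call yields shards that do not come from one logical tensor, which is why the default of 0 is described as preserving single-device semantics.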
Along with distribute_tensor(), DTensor also offers a distribute_module() API to make sharding at the nn.Module level easier.

This function exposes three functions (partition_fn, input_fn, output_fn) to control the parameters, inputs, and outputs of the module:

1. To shard the module before runtime execution, specify partition_fn (i.e. allow users to convert Module parameters to DTensor parameters according to the specified partition_fn).
2. To control the inputs or outputs of the module during runtime execution, specify input_fn and output_fn (i.e. convert the input to DTensor, and convert the output back to torch.Tensor).

module (nn.Module) – user module to be partitioned.

device_mesh (DeviceMesh) – the device mesh to place the module on.

partition_fn (Callable) – the function to partition parameters (i.e. shard certain parameters across the device_mesh). If partition_fn is not specified, all parameters of module are replicated across the mesh by default.

input_fn (Callable) – specifies the input distribution, i.e. controls how the input of the module is sharded. input_fn is installed as a module forward_pre_hook (pre-forward hook).

output_fn (Callable) – specifies the output distribution, i.e. controls how the output is sharded, or converts it back to torch.Tensor. output_fn is installed as a module forward_hook (post-forward hook).

Returns: A module whose parameters/buffers are all DTensors.

When the DeviceMesh is initialized with the xla device_type, distribute_module returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See this issue for more details. The XLA integration is experimental and subject to change.
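The division of labor between partition_fn, input_fn, and output_fn can be sketched with a toy module in pure Python. ToyModule and toy_distribute_module are hypothetical stand-ins. In this model, parameters are partitioned once at wrap time (replicated by default), input_fn runs like a pre-forward hook, and output_fn runs like a post-forward hook. The one-argument callables are a simplification; check the distribute_module reference for the exact signatures PyTorch expects.

```python
class ToyModule:
    """Stand-in for nn.Module: named params plus pre/post forward hooks."""

    def __init__(self, params, forward_fn):
        self.params = dict(params)
        self.forward_fn = forward_fn
        self.pre_hooks = []   # plays the role of forward_pre_hook
        self.post_hooks = []  # plays the role of forward_hook

    def __call__(self, x):
        for hook in self.pre_hooks:
            x = hook(x)
        out = self.forward_fn(self.params, x)
        for hook in self.post_hooks:
            out = hook(out)
        return out


def toy_distribute_module(module, partition_fn=None, input_fn=None, output_fn=None):
    """Hypothetical model of the three control points; not the PyTorch implementation."""
    for name, value in list(module.params.items()):
        if partition_fn is not None:
            module.params[name] = partition_fn(name, value)
        else:
            # Default described above: replicate every parameter across the mesh.
            module.params[name] = {"placement": "replicate", "data": value}
    if input_fn is not None:
        module.pre_hooks.append(input_fn)
    if output_fn is not None:
        module.post_hooks.append(output_fn)
    return module


events = []


def to_dtensor(x):
    events.append("input_fn")  # a real input_fn would wrap x as a DTensor
    return x


def to_local(out):
    events.append("output_fn")  # a real output_fn might convert back to torch.Tensor
    return out


model = ToyModule({"scale": 3}, lambda params, x: x * params["scale"]["data"])
toy_distribute_module(model, input_fn=to_dtensor, output_fn=to_local)
print(model.params["scale"])  # {'placement': 'replicate', 'data': 3}
print(model(2))               # 6
print(events)                 # ['input_fn', 'output_fn']
```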
DTensor also provides dedicated tensor factory functions for creating a DTensor directly through torch.Tensor-like factory APIs (e.g. torch.ones, torch.empty, etc.), by additionally specifying the DeviceMesh and Placement of the created DTensor:

zeros() returns a DTensor filled with the scalar value 0.

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))

requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

device_mesh – DeviceMesh type, contains the mesh info of ranks.

placements – a sequence of Placement type: Shard, Replicate

Returns: A DTensor object on each rank.

ones() returns a DTensor filled with the scalar value 1, with the shape defined by the variable argument size.

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))

dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

device_mesh – DeviceMesh type, contains the mesh info of ranks.

placements – a sequence of Placement type: Shard, Replicate

Returns: A DTensor object on each rank.

empty() returns a DTensor filled with uninitialized data. The shape of the DTensor is defined by the variable argument size.

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))

dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

device_mesh – DeviceMesh type, contains the mesh info of ranks.

placements – a sequence of Placement type: Shard, Replicate

Returns: A DTensor object on each rank.

full() returns a DTensor filled with fill_value according to device_mesh and placements, with the shape defined by the argument size.

size (int...) – a sequence of integers defining the shape of the output DTensor, given as a list or tuple. E.g.: full([1,2,3..], fill_value) or full((1,2,3..), fill_value)

fill_value (Scalar) – the value to fill the output tensor with.

dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

device_mesh – DeviceMesh type, contains the mesh info of ranks.

placements – a sequence of Placement type: Shard, Replicate

Returns: A DTensor object on each rank.

rand() returns a DTensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable argument size.

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))

dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

device_mesh – DeviceMesh type, contains the mesh info of ranks.

placements – a sequence of Placement type: Shard, Replicate

Returns: A DTensor object on each rank.

randn() returns a DTensor filled with random numbers from a normal distribution with mean 0 and variance 1. The shape of the tensor is defined by the variable argument size.

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))

dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

device_mesh – DeviceMesh type, contains the mesh info of ranks.

placements – a sequence of Placement type: Shard, Replicate

Returns: A DTensor object on each rank.

DTensor provides distributed RNG functionality to ensure that random operations on sharded tensors produce unique values, and random operations on replicated tensors produce the same values. This system requires all participating ranks (e.g. SPMD ranks) to start with the same generator state before each DTensor random operation. If that holds, they all end up in the same state after each DTensor random operation completes. No communication is performed during random operations to synchronize RNG states.

Operators that accept a generator kwarg use the user-passed generator if one is given, and the default generator for the device otherwise. Whichever generator is used, it is advanced after the DTensor operation. It is valid to use the same generator for both DTensor and non-DTensor operations. In that case, however, care must be taken to ensure that the non-DTensor operations advance the generator state equally on all ranks.

When using DTensor together with Pipeline Parallelism, ranks in different pipeline stages should use distinct seeds, and ranks within the same pipeline stage should use the same seed.

DTensor’s RNG infrastructure is based on the Philox RNG algorithm and supports any Philox-based backend (CUDA and other CUDA-like devices), but it does not yet support the CPU backend.
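The RNG contract above can be modeled with a counter-based generator in pure Python. The sketch makes three assumptions: a 1-D mesh, a shared (seed, offset) state on every rank, and that each random op advances the offset by the global element count. counter_rand uses random.Random seeded per position as a stand-in for Philox; it is not Philox itself. toy_dtensor_rand is hypothetical. Sharded ranks read disjoint positions of one global stream, replicated ranks read the same positions, and no state is exchanged.

```python
import random


def counter_rand(seed: int, counter: int) -> float:
    """Deterministic uniform value for (seed, counter), standing in for a
    counter-based generator where any stream position can be computed directly."""
    return random.Random(f"{seed}:{counter}").random()


def toy_dtensor_rand(state: dict, numel: int, rank: int, world_size: int,
                     placement: str) -> list[float]:
    """Model a rand() call on a 1-D mesh with no communication.

    state holds "seed" and "offset" and is assumed identical on every rank beforehand.
    """
    base = state["offset"]
    if placement == "replicate":
        counters = range(numel)  # every rank draws the same global positions
    elif placement == "shard":
        chunk = -(-numel // world_size)  # ceil division, torch.chunk-style
        counters = range(min(numel, rank * chunk), min(numel, (rank + 1) * chunk))
    else:
        raise ValueError(f"unknown placement: {placement}")
    local = [counter_rand(state["seed"], base + c) for c in counters]
    # Every rank advances by the *global* numel, so all ranks stay in sync.
    state["offset"] = base + numel
    return local


world_size = 2
states = [{"seed": 1234, "offset": 0} for _ in range(world_size)]
shards = [toy_dtensor_rand(states[r], 6, r, world_size, "shard") for r in range(world_size)]
print(states[0] == states[1])      # True: same generator state afterwards
replicas = [toy_dtensor_rand(states[r], 6, r, world_size, "replicate") for r in range(world_size)]
print(replicas[0] == replicas[1])  # True: replicated values are identical
```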
When launching the program, you can turn on additional logging using the TORCH_LOGS environment variable from torch._logging:

- TORCH_LOGS=+dtensor displays logging.DEBUG messages and all levels above it.
- TORCH_LOGS=dtensor displays logging.INFO messages and above.
- TORCH_LOGS=-dtensor displays logging.WARNING messages and above.

To debug a program that uses DTensor and see which collectives happened under the hood, DTensor provides CommDebugMode:

CommDebugMode is a context manager that counts the number of functional collectives within its context. It does this using a TorchDispatchMode.

Not all collectives are supported yet.

Generates a detailed table displaying operations and collective tracing information at the module level. The amount of information depends on noise_level:

0. prints module-level collective counts
1. prints DTensor operations not included in trivial operations, module information
2. prints operations not included in trivial operations
3. prints all operations

Creates a JSON file used to build a browser visual, with the following noise levels:

0. prints module-level collective counts
1. prints DTensor operations not included in trivial operations
2. prints operations not included in trivial operations
3. prints all operations

Returns the communication counts as a dictionary.

dict[str, dict[str, Any]]

dict[str, dict[str, Any]]

Alternative to the console CommDebugMode output: writes the output to a file specified by the user.
To visualize the sharding of a DTensor with fewer than 3 dimensions, DTensor provides visualize_sharding():

Visualizes sharding in the terminal for DTensors that are 1D or 2D.

This requires the tabulate package, or rich and matplotlib. No sharding info is printed for empty tensors.
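The information visualize_sharding() prints, namely which index ranges each rank owns, can be computed by hand for a simple case. The sketch assumes a 2-D tensor on a 1-D mesh and torch.chunk-style splits. The tensor is either sharded (Shard(0) or Shard(1)) or fully replicated (shard_dim=None). sharding_table is a hypothetical helper and does not reproduce the real output format.

```python
import math


def chunk_range(size, parts, index):
    """[start, end) of chunk `index` when splitting `size` torch.chunk-style into `parts`."""
    chunk = math.ceil(size / parts)
    start = min(size, index * chunk)
    return start, min(size, start + chunk)


def sharding_table(shape, world_size, shard_dim):
    """Per-rank (rank, row_range, col_range) for a 2-D tensor on a 1-D mesh."""
    rows, cols = shape
    table = []
    for rank in range(world_size):
        row_range, col_range = (0, rows), (0, cols)  # replicated by default
        if shard_dim == 0:
            row_range = chunk_range(rows, world_size, rank)
        elif shard_dim == 1:
            col_range = chunk_range(cols, world_size, rank)
        table.append((rank, row_range, col_range))
    return table


for rank, r, c in sharding_table((4, 6), world_size=2, shard_dim=1):
    print(f"rank {rank}: rows {r[0]}:{r[1]}, cols {c[0]}:{c[1]}")
# rank 0: rows 0:4, cols 0:3
# rank 1: rows 0:4, cols 3:6
```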
DTensor also provides a set of experimental features. These features are either in the prototyping stage, or their basic functionality is done and they are awaiting user feedback. Please submit an issue to PyTorch if you have feedback on these features.

context_parallel is an experimental API to enable context parallelism (CP). This API performs two actions:

1) It patches SDPA (torch.nn.functional.scaled_dot_product_attention) with a CP-enabled version.
2) It shards buffers along the sequence dimension, so that each rank keeps the shard corresponding to its position in the mesh.

mesh (DeviceMesh) – the device mesh for the context parallelism.

buffers (Optional[List[torch.Tensor]]) – buffers whose usage depends on the sequence dimension. Examples are the input batch, labels, and positional embedding buffers. These buffers must be sharded along the sequence dimension to ensure correctness. The sharding happens in-place, so a buffer’s shape changes within the context. The buffers are restored after the context finishes; no_restore_buffers can be used to specify which buffers don’t need to be restored. Note that buffers should not contain any nn.Parameter.

buffer_seq_dims (Optional[List[int]]) – the sequence dimensions of buffers.

no_restore_buffers (Optional[Set[torch.Tensor]]) – buffers in this set won’t be restored after the context exits. This set must be a subset of buffers. If some buffers won’t be used after the context exits, they can be put in this set to avoid extra restore time.

Return type: Generator[None, None, None]

torch.distributed.tensor.experimental.context_parallel is a prototype feature in PyTorch. The API is subject to change.
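The in-place shard and restore behavior described for buffers and no_restore_buffers can be modeled with plain Python lists. This sketch is a hypothetical stand-in with three simplifications. It shards each 1-D buffer into contiguous chunks; the real implementation may use a different layout, for example a load-balanced one. It tracks buffers by id() because lists are not hashable. It does not model the SDPA patching.

```python
from contextlib import contextmanager


@contextmanager
def toy_context_parallel(rank, world_size, buffers, no_restore_buffers=()):
    """Hypothetical model of context_parallel's buffer handling (1-D list buffers)."""
    no_restore = {id(b) for b in no_restore_buffers}
    originals = [list(b) for b in buffers]  # keep full copies for restoring
    for buf in buffers:
        chunk = -(-len(buf) // world_size)  # ceil division
        # In-place update: every reference to this buffer now sees the shard.
        buf[:] = buf[rank * chunk:(rank + 1) * chunk]
    try:
        yield
    finally:
        for buf, original in zip(buffers, originals):
            if id(buf) not in no_restore:
                buf[:] = original


tokens = list(range(8))  # stand-in for an input batch along the sequence dim
labels = list(range(8))
with toy_context_parallel(1, 2, [tokens, labels], no_restore_buffers=[labels]):
    print(tokens)  # [4, 5, 6, 7]
print(tokens)      # [0, 1, 2, 3, 4, 5, 6, 7]  restored on exit
print(labels)      # [4, 5, 6, 7]  listed in no_restore_buffers
```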
local_map() is an experimental API that allows users to pass DTensors to a function that is written to be applied on torch.Tensors. It does this by extracting the local components of the DTensors, calling the function, and wrapping the outputs into DTensors according to out_placements.

func (Callable) – the function to be applied on each local shard of DTensors.

out_placements (Union[PlacementType, Tuple[PlacementType, …]]) – the desired placements of the DTensors in func’s flattened output. If the flattened output is a single value, out_placements should be of type PlacementType. Otherwise, if the flattened output has multiple values, out_placements should be a tuple of PlacementType values mapping 1:1 to the flattened output. For Tensor output, we use PlacementType as its placements (a Tuple[Placement] value). For non-Tensor output, the PlacementType should be None. The only exception is when no DTensor argument is passed in. In this case, even if out_placements is not None, the resulting function should ignore the desired placements, because the function is not running with DTensors.

in_placements (Tuple[PlacementType, …], optional) – the required placements of the DTensors in the flattened inputs of func. If in_placements is specified, local_map() checks whether the placements of each DTensor argument are the same as the required placements. If the placements differ and redistribute_inputs is False, an exception is raised. Otherwise, if redistribute_inputs is True, the argument is first redistributed to the required sharding placements before its local tensor is passed to func. The only exception is when the required placements are not None and the argument is a torch.Tensor. In this case, the placement check is skipped and the argument is passed directly to func. If in_placements is None, no placement check is performed. Default: None

in_grad_placements (Tuple[PlacementType, …], optional) – placement hints for the gradients of the DTensors corresponding to the flattened input DTensors. This argument is the hint that users can give to to_local() in case the gradient layout of the local tensor input does not match its DTensor input layout. If not specified, we assume the gradient layout of the local tensor input remains the same as the original DTensor input and use that for gradient computation. Default: None.

device_mesh (DeviceMesh, optional) – the device mesh that the output DTensors are placed on. If not specified, this is inferred from the first input DTensor’s device mesh. Default: None.

redistribute_inputs (bool, optional) – whether to reshard the input DTensors when their placements differ from the required input placements. If this value is False and some DTensor input has a different placement, an exception is raised. Default: False.

Returns: A Callable that applies func to each local shard of the input DTensor and returns a DTensor constructed from the return value of func.

Raises:

AssertionError – For any non-DTensor output, the corresponding output placement in out_placements must be None. An AssertionError is raised if this is not the case.

ValueError – If redistribute_inputs=False but the input DTensor needs a redistribution according to in_placements.

This API is currently experimental and subject to change.
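A simplified model of local_map() shows the unwrap, call, and re-wrap flow together with the in_placements check. ToyDTensor and toy_local_map are hypothetical. Placements are strings, func is assumed to return a single value, and redistribution is not implemented. A placement mismatch therefore simply raises ValueError, as it would with redistribute_inputs=False.

```python
from dataclasses import dataclass


@dataclass
class ToyDTensor:
    """Stand-in for a DTensor: a local shard plus its placements, written as strings."""
    local: list
    placements: tuple


def toy_local_map(func, out_placements, in_placements=None, redistribute_inputs=False):
    """Hypothetical sketch of local_map() semantics for a single-output func."""

    def wrapped(*args):
        saw_dtensor = False
        local_args = []
        for i, arg in enumerate(args):
            if isinstance(arg, ToyDTensor):
                saw_dtensor = True
                if in_placements is not None and arg.placements != in_placements[i]:
                    if not redistribute_inputs:
                        raise ValueError(
                            f"arg {i} has placements {arg.placements}, expected {in_placements[i]}"
                        )
                    raise NotImplementedError("redistribution is not modeled in this sketch")
                local_args.append(arg.local)  # unwrap to the local shard
            else:
                local_args.append(arg)  # plain values skip the placement check
        out = func(*local_args)
        # Without any DTensor input, out_placements is ignored and the raw result returned.
        return ToyDTensor(out, out_placements) if saw_dtensor else out

    return wrapped


double = toy_local_map(
    lambda xs: [2 * v for v in xs],
    out_placements=("Shard(0)",),
    in_placements=(("Shard(0)",),),
)
print(double(ToyDTensor([1, 2], ("Shard(0)",))))  # ToyDTensor(local=[2, 4], placements=('Shard(0)',))
print(double([1, 2]))                              # [2, 4]
```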
register_sharding() is an experimental API that allows users to register sharding strategies for an operator when the tensor inputs and outputs are DTensors. It can be useful in two cases:

1. There is no default sharding strategy for op, e.g. when op is a custom operator that is not supported by DTensor.
2. Users would like to override the default sharding strategies of existing operators.

op (Union[OpOverload, List[OpOverload]]) – An op or a list of ops to register the customized sharding function for.

Returns: A function decorator which can be used to wrap a function that defines the sharding strategy for the operator specified in op. The defined sharding strategy is registered to DTensor and overrides the default sharding strategy if DTensor has already implemented the operator. The customized sharding function takes the same inputs as the original op. The exception is that if an arg is a torch.Tensor, it is replaced by a tensor-like object that DTensor uses internally. The function should return a sequence of 2-tuples, each specifying acceptable output placements and their corresponding input placements.

This API is currently experimental and subject to change.
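A sharding function for register_sharding() returns a sequence of (output_placements, input_placements) pairs. The sketch below follows that contract for a softmax-like op in plain Python. Placements are written as strings instead of Replicate()/Shard() objects, so the code runs without torch. The function name and the op it would be registered for are assumptions, and the decorator appears only in a comment. The strategy offers a fully replicated option, plus sharding on every dimension except the one softmax normalizes over.

```python
def softmax_sharding_strategies(ndim, dim):
    """Return (output_placements, input_placements) pairs for a softmax-like op.

    Placements are strings standing in for DTensor Placement objects; the two
    non-tensor args (dim, half_to_float) get None. With DTensor, a function like
    this would be decorated with something like
    @register_sharding(torch.ops.aten._softmax.default)  # assumed target op
    """
    softmax_dim = dim if dim >= 0 else dim + ndim
    strategies = [(["Replicate()"], ["Replicate()", None, None])]
    for shard_dim in range(ndim):
        # Softmax normalizes along softmax_dim, so that dim must stay whole on each rank.
        if shard_dim != softmax_dim:
            strategies.append(([f"Shard({shard_dim})"], [f"Shard({shard_dim})", None, None]))
    return strategies


for out_placements, in_placements in softmax_sharding_strategies(ndim=2, dim=-1):
    print(out_placements, in_placements)
# ['Replicate()'] ['Replicate()', None, None]
# ['Shard(0)'] ['Shard(0)', None, None]
```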
---

## FullyShardedDataParallel

**URL:** https://pytorch.org/docs/stable/fsdp.html

**Contents:**
- FullyShardedDataParallel

Created On: Feb 02, 2022 | Last Updated On: Jun 11, 2025

A wrapper for sharding module parameters across data parallel workers.

This is inspired by Xu et al. as well as ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.

Using FSDP involves wrapping your module first and initializing your optimizer afterwards. This order is required because FSDP changes the parameter variables.

When setting up FSDP, you need to consider the destination CUDA device. If the device has an ID (dev_id), you have three options:

1. Place the module on that device.
2. Set the device using torch.cuda.set_device(dev_id).
3. Pass dev_id into the device_id constructor argument.

This ensures that the FSDP instance’s compute device is the destination device. For options 1 and 3, FSDP initialization always occurs on GPU. For option 2, FSDP initialization happens on the module’s current device, which may be a CPU.

If you’re using the sync_module_states=True flag, you need to either ensure that the module is on a GPU, or use the device_id argument to specify a CUDA device that FSDP will move the module to in the FSDP constructor. This is necessary because sync_module_states=True requires GPU communication.

FSDP also takes care of moving the input tensors of the forward method to the GPU compute device, so you don’t need to move them from CPU manually.

For use_orig_params=True, ShardingStrategy.SHARD_GRAD_OP exposes the unsharded parameters after forward, not the sharded parameters, unlike ShardingStrategy.FULL_SHARD. If you want to inspect the gradients, you can use the summon_full_params method with with_grads=True.

With limit_all_gathers=True, you may see a gap in the FSDP pre-forward where the CPU thread is not issuing any kernels. This is intentional and shows the rate limiter in effect. Synchronizing the CPU thread in that way prevents over-allocating memory for subsequent all-gathers, and it should not actually delay GPU kernel execution.

FSDP replaces managed modules’ parameters with torch.Tensor views during forward and backward computation for autograd-related reasons. If your module’s forward relies on saved references to the parameters instead of reacquiring the references each iteration, then it will not see FSDP’s newly created views, and autograd will not work correctly.

Finally, when using sharding_strategy=ShardingStrategy.HYBRID_SHARD with the sharding process group being intra-node and the replication process group being inter-node, setting NCCL_CROSS_NIC=1 can help improve the all-reduce times over the replication process group for some cluster setups.
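For the HYBRID_SHARD layout mentioned above, with intra-node sharding and inter-node replication, the rank groups can be written out explicitly. The sketch assumes that ranks are numbered node by node and that every node has the same number of GPUs. hybrid_shard_groups is hypothetical and only computes rank lists, not ProcessGroup objects. The replication groups are the ones whose all-reduce crosses nodes, which is the traffic NCCL_CROSS_NIC=1 may help with.

```python
def hybrid_shard_groups(world_size, gpus_per_node):
    """Rank lists for an intra-node shard / inter-node replicate layout.

    Assumes rank = node * gpus_per_node + local_rank (hypothetical helper).
    """
    if world_size % gpus_per_node:
        raise ValueError("world_size must be a multiple of gpus_per_node")
    num_nodes = world_size // gpus_per_node
    # Shard within a node: all local ranks of one node form one group.
    shard_groups = [
        list(range(n * gpus_per_node, (n + 1) * gpus_per_node)) for n in range(num_nodes)
    ]
    # Replicate across nodes: the same local rank on every node forms one group.
    replicate_groups = [
        list(range(local, world_size, gpus_per_node)) for local in range(gpus_per_node)
    ]
    return shard_groups, replicate_groups


shard, replicate = hybrid_shard_groups(world_size=8, gpus_per_node=4)
print(shard)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicate)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

When passing the groups yourself, the process_group parameter takes them as a tuple in (shard group, replicate group) order.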
There are several limitations to be aware of when using FSDP:

- FSDP currently does not support gradient accumulation outside no_sync() when using CPU offloading. This is because FSDP uses the newly-reduced gradient instead of accumulating with any existing gradient, which can lead to incorrect results.
- FSDP does not support running the forward pass of a submodule that is contained in an FSDP instance. This is because the submodule’s parameters will be sharded, but the submodule itself is not an FSDP instance, so its forward pass will not all-gather the full parameters appropriately.
- FSDP does not work with double backwards due to the way it registers backward hooks.
- FSDP has some constraints when freezing parameters. For use_orig_params=False, each FSDP instance must manage parameters that are all frozen or all non-frozen. For use_orig_params=True, FSDP supports mixing frozen and non-frozen parameters, but it’s recommended to avoid doing so to prevent higher than expected gradient memory usage.
- As of PyTorch 1.12, FSDP offers limited support for shared parameters. If enhanced shared parameter support is needed for your use case, please post in this issue.
- You should avoid modifying the parameters between forward and backward without using the summon_full_params context, as the modifications may not persist.

module (nn.Module) – This is the module to be wrapped with FSDP.

process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – This is the process group over which the model is sharded and thus the one used for FSDP’s all-gather and reduce-scatter collective communications. If None, then FSDP uses the default process group. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a tuple of process groups, representing the groups over which to shard and replicate, respectively. If None, then FSDP constructs process groups for the user to shard intra-node and replicate inter-node. (Default: None)

sharding_strategy (Optional[ShardingStrategy]) – This configures the sharding strategy, which may trade off memory saving and communication overhead. See ShardingStrategy for details. (Default: FULL_SHARD)

cpu_offload (Optional[CPUOffload]) – This configures CPU offloading. If this is set to None, then no CPU offloading happens. See CPUOffload for details. (Default: None)

auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) – This specifies a policy to apply FSDP to submodules of module, which is needed for communication and computation overlap and thus affects performance. If None, then FSDP only applies to module, and users should manually apply FSDP to parent modules themselves (proceeding bottom-up). For convenience, this accepts ModuleWrapPolicy directly, which allows users to specify the module classes to wrap (e.g. the transformer block). Otherwise, this should be a callable that takes three arguments: module: nn.Module, recurse: bool, and nonwrapped_numel: int. It should return a bool that specifies, when recurse=False, whether the passed-in module should have FSDP applied. When recurse=True, the bool instead specifies whether the traversal should continue into the module’s subtree. Users may add additional arguments to the callable. The size_based_auto_wrap_policy in torch.distributed.fsdp.wrap.py gives an example callable that applies FSDP to a module if the parameters in its subtree exceed 100M numel. We recommend printing the model after applying FSDP and adjusting as needed. Example:

```python
import functools

import torch.nn as nn


def custom_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Additional custom arguments
    min_num_params: int = int(1e8),
) -> bool:
    return nonwrapped_numel >= min_num_params


# Configure a custom `min_num_params`
my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
```
backward_prefetch (Optional[BackwardPrefetch]) – This configures explicit backward prefetching of all-gathers. If None, then FSDP does not backward prefetch, and there is no communication and computation overlap in the backward pass. See BackwardPrefetch for details. (Default: BACKWARD_PRE)
mixed_precision (Optional[MixedPrecision]) – This configures native mixed precision for FSDP. If this is set to None, then no mixed precision is used. Otherwise, parameter, buffer, and gradient reduction dtypes can be set. See MixedPrecision for details. (Default: None)
ignored_modules (Optional[Iterable[torch.nn.Module]]) – Modules whose own parameters and child modules’ parameters and buffers are ignored by this instance. None of the modules directly in ignored_modules should be FullyShardedDataParallel instances, and any child modules that are already-constructed FullyShardedDataParallel instances will not be ignored if they are nested under this instance. This argument may be used to avoid sharding specific parameters at module granularity when using an auto_wrap_policy or if parameters’ sharding is not managed by FSDP. (Default: None)
param_init_fn (Optional[Callable[[nn.Module], None]]) – A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. As of v1.12, FSDP detects modules with parameters or buffers on meta device via is_meta and either applies param_init_fn if specified or calls nn.Module.reset_parameters() otherwise. In both cases, the implementation should only initialize the parameters/buffers of the module, not those of its submodules. This is to avoid re-initialization. In addition, FSDP also supports deferred initialization via torchdistX's (pytorch/torchdistX) deferred_init() API, where the deferred modules are initialized by calling param_init_fn if specified or torchdistX's default materialize_module() otherwise. If param_init_fn is specified, then it is applied to all meta-device modules, meaning that it should probably branch on the module type. FSDP calls the initialization function before parameter flattening and sharding. Example:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torchdistx import deferred_init

# MyModule is a placeholder for your own nn.Module subclass
module = MyModule(device="meta")

def my_init_fn(module: nn.Module):
    # E.g. initialize depending on the module type
    ...

fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
print(next(fsdp_model.parameters()).device)  # current CUDA device
# With torchdistX
module = deferred_init.deferred_init(MyModule, device="cuda")
# Will initialize via deferred_init.materialize_module().
fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
```
device_id (Optional[Union[int, torch.device]]) – An int or torch.device giving the CUDA device on which FSDP initialization takes place, including the module initialization if needed and the parameter sharding. This should be specified to improve initialization speed if module is on CPU. If the default CUDA device was set (e.g. via torch.cuda.set_device), then the user may pass torch.cuda.current_device() to this. (Default: None)
sync_module_states (bool) – If True, then each FSDP module will broadcast module parameters and buffers from rank 0 to ensure that they are replicated across ranks (adding communication overhead to this constructor). This can help load state_dict checkpoints via load_state_dict in a memory efficient way. See FullStateDictConfig for an example of this. (Default: False)
forward_prefetch (bool) – If True, then FSDP explicitly prefetches the next forward-pass all-gather before the current forward computation. This is only useful for CPU-bound workloads, in which case issuing the next all-gather earlier may improve overlap. This should only be used for static-graph models since the prefetching follows the first iteration’s execution order. (Default: False)
limit_all_gathers (bool) – If True, then FSDP explicitly synchronizes the CPU thread so that GPU memory is used by the all-gathered parameters of at most two consecutive FSDP instances: the current instance running computation and the next instance whose all-gather is prefetched. If False, then FSDP allows the CPU thread to issue all-gathers without any extra synchronization. (Default: True) We often refer to this feature as the "rate limiter". This flag should only be set to False for specific CPU-bound workloads with low memory pressure, in which case the CPU thread can aggressively issue all kernels without concern for GPU memory usage.
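
The effect of the rate limiter can be illustrated with a toy model. It assumes the worst case for memory: the CPU thread runs far ahead of the GPU, so no instance frees its parameters until the CPU is forced to wait. `peak_resident_instances` is a hypothetical helper. Real FSDP synchronizes on CUDA events rather than on a Python queue.

```python
from collections import deque
from typing import Optional


def peak_resident_instances(num_instances: int, limit: Optional[int]) -> int:
    """Peak number of FSDP instances holding unsharded params during one forward pass.

    limit=None models limit_all_gathers=False; limit=2 models the rate limiter.
    """
    resident = deque()  # instances whose all-gather was issued but that have not freed yet
    peak = 0
    for i in range(num_instances):
        if limit is not None and len(resident) >= limit:
            # Rate limiter: the CPU blocks until the oldest instance has freed its params
            resident.popleft()
        resident.append(i)  # the CPU issues the all-gather for instance i
        peak = max(peak, len(resident))
    return peak


print(peak_resident_instances(12, limit=None))  # 12: every layer's full params at once
print(peak_resident_instances(12, limit=2))     # 2: current instance plus the prefetched one
```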
use_orig_params (bool) – Setting this to True has FSDP use module's original parameters. FSDP exposes those original parameters to the user via nn.Module.named_parameters() instead of FSDP's internal FlatParameters. This means that the optimizer step runs on the original parameters, enabling per-original-parameter hyperparameters. FSDP preserves the original parameter variables and manipulates their data between unsharded and sharded forms, where they are always views into the underlying unsharded or sharded FlatParameter, respectively. With the current algorithm, the sharded form is always 1D, losing the original tensor structure. An original parameter may have all, some, or none of its data present for a given rank. In the none case, its data will be like a size-0 empty tensor. Users should not author programs relying on what data is present for a given original parameter in its sharded form. True is required to use torch.compile(). Setting this to False exposes FSDP's internal FlatParameters to the user via nn.Module.named_parameters(). (Default: False)
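
The "all, some, or none" behavior of sharded original parameters can be reproduced with a small sketch of the flattening arithmetic. The sketch assumes a simplified layout: concatenate the parameters, right-pad to a multiple of the world size, then split evenly. It ignores any alignment padding that particular FSDP versions may insert between parameters. `shard_views` is a hypothetical helper.

```python
from typing import List, Tuple


def shard_views(numels: List[int], world_size: int) -> List[List[Tuple[int, int]]]:
    """For each rank, the [start, end) slice of every original parameter held in its shard.

    (0, 0) marks a size-0 view, i.e. the rank holds none of that parameter.
    """
    total = sum(numels)
    chunk = -(-total // world_size)  # ceil(total / world_size); padding fills the last chunk
    views = []
    for rank in range(world_size):
        lo, hi = rank * chunk, (rank + 1) * chunk
        rank_views, offset = [], 0
        for n in numels:
            start, end = max(lo, offset) - offset, min(hi, offset + n) - offset
            rank_views.append((start, end) if end > start else (0, 0))
            offset += n
        views.append(rank_views)
    return views


# Three original parameters with 6, 3 and 1 elements, sharded over 4 ranks (chunk size 3).
for rank, rank_views in enumerate(shard_views([6, 3, 1], world_size=4)):
    print(rank, rank_views)
# 0 [(0, 3), (0, 0), (0, 0)]   some of param 0, none of params 1 and 2
# 1 [(3, 6), (0, 0), (0, 0)]
# 2 [(0, 0), (0, 3), (0, 0)]   all of param 1
# 3 [(0, 0), (0, 0), (0, 1)]   all of param 2, plus 2 padding elements
```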
ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – Ignored parameters or modules that will not be managed by this FSDP instance, meaning that the parameters are not sharded and their gradients are not reduced across ranks. This argument unifies with the existing ignored_modules argument, and we may deprecate ignored_modules soon. For backward compatibility, we keep both ignored_states and ignored_modules, but FSDP only allows one of them to be specified as not None.
device_mesh (Optional[DeviceMesh]) – DeviceMesh can be used as an alternative to process_group. When device_mesh is passed, FSDP will use the underlying process groups for all-gather and reduce-scatter collective communications. Therefore, these two args need to be mutually exclusive. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a 2D DeviceMesh instead of a tuple of process groups. For 2D FSDP + TP, users are required to pass in device_mesh instead of process_group. For more DeviceMesh info, please visit: https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
Apply fn recursively to every submodule (as returned by .children()) as well as self.
Typical use includes initializing the parameters of a model (see also torch.nn.init).
Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another summon_full_params context.
fn (Module -> None) – function to be applied to each submodule
Check if this instance is a root FSDP module.
Clip the gradient norm of all parameters.
The norm is computed over all parameters’ gradients as viewed as a single vector, and the gradients are modified in-place.
max_norm (float or int) – max norm of the gradients
norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm.
Total norm of the parameters (viewed as a single vector).
If every FSDP instance uses NO_SHARD, meaning that no gradients are sharded across ranks, then you may directly use torch.nn.utils.clip_grad_norm_().
If at least some FSDP instance uses a sharded strategy (i.e. one other than NO_SHARD), then you should use this method instead of torch.nn.utils.clip_grad_norm_() since this method handles the fact that gradients are sharded across ranks.
The total norm returned will have the "largest" dtype across all parameters/gradients as defined by PyTorch's type promotion semantics. For example, if all parameters/gradients use a low precision dtype, then the returned norm's dtype will be that low precision dtype, but if there exists at least one parameter/gradient using FP32, then the returned norm's dtype will be FP32.
This needs to be called on all ranks since it uses collective communications.
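
A single-process simulation shows why a sharded strategy needs its own clipping method. Each rank only sees its gradient shard, so the partial results must be combined with a collective before any rank can compute the clip coefficient. `sharded_clip_grad_norm` is a hypothetical helper, and the lists of floats stand in for gradient tensors. The clip coefficient `max_norm / (total_norm + 1e-6)`, capped at 1, follows the convention of `torch.nn.utils.clip_grad_norm_`.

```python
import math
from typing import List


def sharded_clip_grad_norm(grad_shards: List[List[float]], max_norm: float,
                           norm_type: float = 2.0) -> float:
    """Simulated clip over gradients sharded across ranks (grad_shards[r] is rank r's shard)."""
    if math.isinf(norm_type):
        # Each rank takes its local max |g|; the all-reduce uses MAX.
        total_norm = max(max((abs(g) for g in shard), default=0.0) for shard in grad_shards)
    else:
        # Each rank sums |g|^p over its shard; the all-reduce uses SUM; then take the p-th root.
        local_sums = [sum(abs(g) ** norm_type for g in shard) for shard in grad_shards]
        total_norm = sum(local_sums) ** (1.0 / norm_type)
    clip_coef = min(max_norm / (total_norm + 1e-6), 1.0)  # never scale gradients up
    for shard in grad_shards:  # every rank scales its own shard in place
        shard[:] = [g * clip_coef for g in shard]
    return total_norm


# A gradient [3.0, 4.0] split across two ranks: local norms are 3 and 4, the true norm is 5.
shards = [[3.0], [4.0]]
print(sharded_clip_grad_norm(shards, max_norm=1.0))  # 5.0
print(math.hypot(shards[0][0], shards[1][0]))        # ~1.0 after clipping
```

Clipping each shard by its local norm would use 3 or 4 instead of 5, so the ranks would scale their pieces of the same gradient inconsistently.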
Flatten a sharded optimizer state-dict.
The API is similar to shard_full_optim_state_dict(). The only difference is that the input sharded_optim_state_dict should be returned from sharded_optim_state_dict(). Therefore, there will be all-gather calls on each rank to gather ShardedTensors.
sharded_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the sharded optimizer state.
model (torch.nn.Module) – Refer to shard_full_optim_state_dict().
optim (torch.optim.Optimizer) – Optimizer for model's parameters.
Refer to shard_full_optim_state_dict().
Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic.
Return all nested FSDP instances.
This possibly includes module itself and only includes FSDP root modules if root_only=True.
module (torch.nn.Module) – Root module, which may or may not be an FSDP module.
root_only (bool) – Whether to return only FSDP root modules. (Default: False)
FSDP modules that are nested in the input module.
List[FullyShardedDataParallel]
Return the full optimizer state-dict.
Consolidates the full optimizer state on rank 0 and returns it as a dict following the convention of torch.optim.Optimizer.state_dict(), i.e. with keys "state" and "param_groups". The flattened parameters in FSDP modules contained in model are mapped back to their unflattened parameters.
This needs to be called on all ranks since it uses collective communications. However, if rank0_only=True, then the state dict is only populated on rank 0, and all other ranks return an empty dict.
Unlike torch.optim.Optimizer.state_dict(), this method uses full parameter names as keys instead of parameter IDs.
Like in torch.optim.Optimizer.state_dict(), the tensors contained in the optimizer state dict are not cloned, so there may be aliasing surprises. For best practices, consider saving the returned optimizer state dict immediately, e.g. using torch.save().
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model's parameters.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer optim representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
rank0_only (bool) – If True, saves the populated dict only on rank 0; if False, saves it on all ranks. (Default: True)
group (dist.ProcessGroup) – Model’s process group or None if using the default process group. (Default: None)
A dict containing the optimizer state for model's original unflattened parameters and including keys "state" and "param_groups" following the convention of torch.optim.Optimizer.state_dict(). If rank0_only=True, then nonzero ranks return an empty dict.
Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at module.
The target module does not have to be an FSDP module.
A StateDictSettings containing the state_dict_type and state_dict / optim_state_dict configs that are currently set.
AssertionError – if the StateDictSettings for different FSDP submodules differ.
Return the wrapped module.
Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix when inside the summon_full_params() context manager.
Iterator[tuple[str, torch.Tensor]]
Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix when inside the summon_full_params() context manager.
Iterator[tuple[str, torch.nn.parameter.Parameter]]
Disable gradient synchronizations across FSDP instances.
Within this context, gradients will be accumulated in module variables, which will later be synchronized in the first forward-backward pass after exiting the context. This should only be used on the root FSDP instance and will recursively apply to all children FSDP instances.
This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync.
When used with CPU offloading, the gradients will not be offloaded to CPU when inside the context manager. Instead, they will only be offloaded right after the eventual sync.
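
What `no_sync()` changes can be sketched with Python lists. Inside the context, each rank keeps a full-size gradient accumulator instead of a shard. Because reduce-scatter is a sum, syncing once at the end gives the same shards as syncing every micro-batch. The helpers are hypothetical, and FSDP's gradient averaging is omitted to keep the arithmetic exact.

```python
from typing import List, Tuple

Grad = List[float]


def reduce_scatter(full_grads: List[Grad]) -> List[Grad]:
    """Sum full-length gradients across ranks, then give each rank one contiguous chunk."""
    world_size = len(full_grads)
    summed = [sum(vals) for vals in zip(*full_grads)]
    chunk = len(summed) // world_size  # assumes numel divisible by world_size
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def accumulate(micro_batches: List[List[Grad]], use_no_sync: bool) -> Tuple[List[Grad], int]:
    """micro_batches[step][rank] is that rank's full local gradient for one micro-batch.

    Returns each rank's final gradient shard and the length of the per-rank accumulator.
    """
    world_size, numel = len(micro_batches[0]), len(micro_batches[0][0])
    if use_no_sync:
        # Inside no_sync(): each rank accumulates FULL unsharded gradients, no communication.
        acc = [[0.0] * numel for _ in range(world_size)]
        for step in micro_batches:
            acc = [[a + g for a, g in zip(acc[r], step[r])] for r in range(world_size)]
        return reduce_scatter(acc), numel  # one sync, as on the first backward after the context
    # Without no_sync(): reduce-scatter every micro-batch and accumulate only the local shard.
    acc = [[0.0] * (numel // world_size) for _ in range(world_size)]
    for step in micro_batches:
        shards = reduce_scatter(step)
        acc = [[a + g for a, g in zip(acc[r], shards[r])] for r in range(world_size)]
    return acc, numel // world_size


micro = [[[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]],
         [[5.0, 6.0, 7.0, 8.0], [50.0, 60.0, 70.0, 80.0]]]
print(accumulate(micro, use_no_sync=True))   # ([[66.0, 88.0], [110.0, 132.0]], 4)
print(accumulate(micro, use_no_sync=False))  # ([[66.0, 88.0], [110.0, 132.0]], 2)
```

Both paths produce the same shards. The no_sync path trades fewer collectives for a per-rank accumulator that is world_size times larger.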
Transform the state-dict of an optimizer corresponding to a sharded model.
The given state-dict can be transformed to one of three types: 1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.
For full optimizer state_dict, all states are unflattened and not sharded. Rank0 only and CPU only can be specified via state_dict_type() to avoid OOM.
For sharded optimizer state_dict, all states are unflattened but sharded. CPU only can be specified via state_dict_type() to further save memory.
For local state_dict, no transformation will be performed. However, a state will be converted from torch.Tensor to ShardedTensor to represent its sharding nature (this is not supported yet).
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model's parameters.
optim_state_dict (Dict[str, Any]) – the target optimizer state_dict to transform. If the value is None, optim.state_dict() will be used. (Default: None)
group (dist.ProcessGroup) – Model's process group across which parameters are sharded or None if using the default process group. (Default: None)
A dict containing the optimizer state for model. The sharding of the optimizer state is based on state_dict_type.
Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
Given an optim_state_dict that has been transformed through optim_state_dict(), this converts it to the flattened optimizer state_dict that can be loaded into optim, the optimizer for model. model must be sharded by FullyShardedDataParallel.
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model's parameters.
optim_state_dict (Dict[str, Any]) – The optimizer states to be loaded.
is_named_optimizer (bool) – Whether this optimizer is a NamedOptimizer or KeyedOptimizer. Only set to True if optim is TorchRec's KeyedOptimizer or torch.distributed's NamedOptimizer.
load_directly (bool) – If this is set to True, this API will also call optim.load_state_dict(result) before returning the result. Otherwise, users are responsible for calling optim.load_state_dict(). (Default: False)
group (dist.ProcessGroup) – Model's process group across which parameters are sharded or None if using the default process group. (Default: None)
Register a communication hook.
This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates gradients across multiple workers. This hook can be used to implement several algorithms like GossipGrad and gradient compression which involve different communication strategies for parameter syncs while training with FullyShardedDataParallel.
FSDP communication hook should be registered before running an initial forward pass and only once.
state (object) – Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker.
hook (Callable) – Callable with one of the following signatures: 1) hook: Callable[torch.Tensor] -> None: This function takes in a Python tensor, which represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). It then performs all necessary processing and returns None; 2) hook: Callable[torch.Tensor, torch.Tensor] -> None: This function takes in two Python tensors. The first one represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). The second represents a pre-sized tensor to store a chunk of a sharded gradient after reduction. In both cases, the callable performs all necessary processing and returns None. Callables with signature 1 are expected to handle gradient communication for the NO_SHARD case. Callables with signature 2 are expected to handle gradient communication for sharded cases.
Re-keys the optimizer state dict optim_state_dict to use the key type optim_state_key_type.
This can be used to achieve compatibility between optimizer state dicts from models with FSDP instances and ones without.
To re-key an FSDP full optimizer state dict (i.e. from full_optim_state_dict()) so that it uses parameter IDs and can be loaded into a non-wrapped model, pass OptimStateKeyType.PARAM_ID together with the non-wrapped model.
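To re-key a normal optimizer state dict from a non-wrapped model so that it can be loaded into a wrapped model, pass OptimStateKeyType.PARAM_NAME together with the non-wrapped model. Then shard the result (e.g. with shard_full_optim_state_dict()) before loading it into the wrapped model's optimizer.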
The optimizer state dict re-keyed using the parameter keys specified by optim_state_key_type.
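
Re-keying only changes how parameters are identified. `torch.optim.Optimizer.state_dict()` uses integer IDs, while `full_optim_state_dict()` uses parameter names. The sketch below shows that mapping on plain dicts. `rekey` is a hypothetical helper, and the sketch assumes the optimizer was built from the parameters in `model.named_parameters()` order. It illustrates the idea only and is not a substitute for `rekey_optim_state_dict()`.

```python
from typing import Any, Dict, List

OptimStateDict = Dict[str, Any]


def rekey(osd: OptimStateDict, mapping: Dict[Any, Any]) -> OptimStateDict:
    """Return a copy of an optimizer state dict with its "state" and "params" keys renamed."""
    return {
        "state": {mapping[k]: v for k, v in osd["state"].items()},
        "param_groups": [{**g, "params": [mapping[p] for p in g["params"]]}
                         for g in osd["param_groups"]],
    }


# Assumed parameter order: what model.named_parameters() yields, matching the optimizer input.
param_names: List[str] = ["fc.weight", "fc.bias"]
name_to_id = {name: i for i, name in enumerate(param_names)}
id_to_name = {i: name for name, i in name_to_id.items()}

by_name = {  # shaped like a full_optim_state_dict() result: keyed by parameter name
    "state": {"fc.weight": {"step": 3, "exp_avg": [0.1, 0.2]},
              "fc.bias": {"step": 3, "exp_avg": [0.5]}},
    "param_groups": [{"lr": 0.01, "params": ["fc.weight", "fc.bias"]}],
}
by_id = rekey(by_name, name_to_id)  # shaped like torch.optim.Optimizer.state_dict()
print(by_id["param_groups"][0]["params"])  # [0, 1]
print(rekey(by_id, id_to_name) == by_name)  # True
```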
Scatter the full optimizer state dict from rank 0 to all other ranks.
Returns the sharded optimizer state dict on each rank. The return value is the same as shard_full_optim_state_dict(), and on rank 0, the first argument should be the return value of full_optim_state_dict().
Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.
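
A back-of-the-envelope estimate puts numbers on this trade-off: aggregate CPU memory and rank-0 send volume. `load_costs` is a hypothetical helper. It ignores padding, replicated scalar states such as `step`, and the GPU staging copies used for NCCL.

```python
def load_costs(full_osd_bytes: int, world_size: int) -> dict:
    """Rough aggregate costs of the two loading paths (an estimate, not a measurement)."""
    return {
        # Every rank materializes the full dict in CPU memory and slices it locally.
        "shard_full_optim_state_dict": {"cpu_bytes": world_size * full_osd_bytes, "comm_bytes": 0},
        # Only rank 0 holds the full dict and sends each other rank roughly 1/world_size of it.
        "scatter_full_optim_state_dict": {
            "cpu_bytes": full_osd_bytes,
            "comm_bytes": full_osd_bytes * (world_size - 1) // world_size,
        },
    }


# Example: a 40 GB full optimizer state dict loaded on 8 ranks of a single machine.
for name, cost in load_costs(40 * 10**9, world_size=8).items():
    print(name, {k: v // 10**9 for k, v in cost.items()})
# shard_full_optim_state_dict {'cpu_bytes': 320, 'comm_bytes': 0}
# scatter_full_optim_state_dict {'cpu_bytes': 40, 'comm_bytes': 35}
```

When all ranks share one host, the shard path's aggregate CPU cost lands on a single machine's RAM. That is the situation where scattering from rank 0 tends to be the safer choice.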
full_optim_state_dict (Optional[Dict[str, Any]]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state if on rank 0; the argument is ignored on nonzero ranks.
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)
group (dist.ProcessGroup) – Model’s process group or None if using the default process group. (Default: None)
The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.
Set the state_dict_type of all the descendant FSDP modules of the target module.
Also takes (optional) configuration for the model's and optimizer's state dict. The target module does not have to be an FSDP module. If the target module is an FSDP module, its state_dict_type will also be changed.
This API should be called for only the top-level (root) module.
This API enables users to transparently use the conventional state_dict API to take model checkpoints in cases where the root FSDP module is wrapped by another nn.Module. For example, after this API is called on the outer (non-FSDP) module with StateDictType.SHARDED_STATE_DICT, calling state_dict() on that module will call state_dict on all non-FSDP instances while dispatching into the sharded_state_dict implementation for FSDP instances.
module (torch.nn.Module) – Root module.
state_dict_type (StateDictType) – the desired state_dict_type to set.
state_dict_config (Optional[StateDictConfig]) – the configuration for the target state_dict_type.
optim_state_dict_config (Optional[OptimStateDictConfig]) – the configuration for the optimizer state dict.
A StateDictSettings that includes the previous state_dict type and configuration for the module.
Shard a full optimizer state-dict.
Remaps the state in full_optim_state_dict to flattened parameters instead of unflattened parameters and restricts to only this rank’s part of the optimizer state. The first argument should be the return value of full_optim_state_dict().
Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.
full_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state.
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)
The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.
Return the optimizer state-dict in its sharded form.
The API is similar to full_optim_state_dict() but this API chunks all non-zero-dimension states to ShardedTensor to save memory. This API should only be used when the model state_dict is derived with the context manager with state_dict_type(SHARDED_STATE_DICT):.
For the detailed usage, refer to full_optim_state_dict().
The returned state dict contains ShardedTensor and cannot be directly used by the regular optim.load_state_dict.
Set the state_dict_type of all the descendant FSDP modules of the target module.
This context manager has the same functionality as set_state_dict_type(). See the documentation of set_state_dict_type() for details.
module (torch.nn.Module) – Root module.
state_dict_type (StateDictType) – the desired state_dict_type to set.
state_dict_config (Optional[StateDictConfig]) – the model state_dict configuration for the target state_dict_type.
optim_state_dict_config (Optional[OptimStateDictConfig]) – the optimizer state_dict configuration for the target state_dict_type.
Expose full params for FSDP instances with this context manager.
Can be useful after forward/backward for a model to get the params for additional processing or checking. It can take a non-FSDP module and will summon full params for all contained FSDP modules as well as their children, depending on the recurse argument.
This can be used on inner FSDPs.
This cannot be used within a forward or backward pass, nor can forward and backward be started from within this context.
Parameters will revert to their local shards after the context manager exits; storage behavior is the same as in forward.
The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless writeback=False, in which case changes will be discarded). When FSDP does not shard the parameters (currently only when world_size == 1 or with the NO_SHARD strategy), the modification is persisted regardless of writeback.
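
The writeback rule can be simulated for one evenly sharded parameter. Every rank works on its own full copy, and on exit each rank copies back only its slice. `summon_full_params_sim` is a hypothetical helper written for illustration, not the FSDP API.

```python
from typing import Callable, List


def summon_full_params_sim(shards: List[List[float]],
                           edit: Callable[[int, List[float]], None],
                           writeback: bool = True) -> List[List[float]]:
    """Simulate the context manager for one parameter sharded evenly across ranks.

    On entry every rank all-gathers the full parameter into its own copy; edit(rank, full)
    stands in for user code on that rank. On exit each rank keeps only its own slice.
    """
    chunk = len(shards[0])
    gathered = [x for shard in shards for x in shard]  # all-gather
    result = []
    for rank, shard in enumerate(shards):
        full = list(gathered)  # this rank's unsharded copy
        edit(rank, full)
        result.append(full[rank * chunk:(rank + 1) * chunk] if writeback else list(shard))
    return result


def zero_on_rank0_only(rank: int, full: List[float]) -> None:
    if rank == 0:
        full[:] = [0.0] * len(full)


shards = [[1.0, 2.0], [3.0, 4.0]]
print(summon_full_params_sim(shards, zero_on_rank0_only))                   # [[0.0, 0.0], [3.0, 4.0]]
print(summon_full_params_sim(shards, zero_on_rank0_only, writeback=False))  # [[1.0, 2.0], [3.0, 4.0]]
```

In this model, an edit made on only one rank survives only in that rank's slice. Modifications meant for the whole parameter therefore need to be performed identically on every rank.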
This method works on modules which are not FSDP themselves but may contain multiple independent FSDP units. In that case, the given arguments will apply to all contained FSDP units.
Note that rank0_only=True in conjunction with writeback=True is not currently supported and will raise an error. This is because model parameter shapes would be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.
Note that offload_to_cpu and rank0_only=False will result in full parameters being redundantly copied to CPU memory for GPUs that reside on the same machine, which may incur the risk of CPU OOM. It is recommended to use offload_to_cpu with rank0_only=True.
recurse (bool, Optional) – recursively summon all params for nested FSDP instances (default: True).
writeback (bool, Optional) – if False, modifications to params are discarded after the context manager exits; disabling this can be slightly more efficient (default: True)
rank0_only (bool, Optional) – if True, full parameters are materialized on only global rank 0. This means that within the context, only rank 0 will have full parameters and the other ranks will have sharded parameters. Note that setting rank0_only=True with writeback=True is not supported, as model parameter shapes will be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.
offload_to_cpu (bool, Optional) – If True, full parameters are offloaded to CPU. Note that this offloading currently only occurs if the parameter is sharded (i.e. not when world_size == 1 or with the NO_SHARD strategy). It is recommended to use offload_to_cpu with rank0_only=True to avoid redundant copies of model parameters being offloaded to the same CPU memory.
with_grads (bool, Optional) – If True, gradients are also unsharded with the parameters. Currently, this is only supported when passing use_orig_params=True to the FSDP constructor and offload_to_cpu=False to this method. (Default: False)
This configures explicit backward prefetching, which improves throughput by enabling communication and computation overlap in the backward pass at the cost of slightly increased memory usage.
|
|
3629
|
+
|
|
3630
|
+
BACKWARD_PRE: This enables the most overlap but increases memory usage the most. This prefetches the next set of parameters before the current set of parameters’ gradient computation. This overlaps the next all-gather and the current gradient computation, and at the peak, it holds the current set of parameters, next set of parameters, and current set of gradients in memory.
|
|
3631
|
+
|
|
3632
|
+
BACKWARD_POST: This enables less overlap but requires less memory usage. This prefetches the next set of parameters after the current set of parameters’ gradient computation. This overlaps the current reduce-scatter and the next gradient computation, and it frees the current set of parameters before allocating memory for the next set of parameters, only holding the next set of parameters and current set of gradients in memory at the peak.
|
|
3633
|
+
|
|
3634
|
+
FSDP’s backward_prefetch argument accepts None, which disables the backward prefetching altogether. This has no overlap and does not increase memory usage. In general, we do not recommend this setting since it may degrade throughput significantly.
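The peak-memory descriptions of the two prefetch values reduce to simple arithmetic. In real code, the choice is passed as backward_prefetch=BackwardPrefetch.BACKWARD_PRE (or BACKWARD_POST) to the FSDP constructor.

The function below is only a toy accounting model. Its name is hypothetical, and it assumes every FSDP unit is the same size. It ignores activations and optimizer state, and it does not model None.

```python
def backward_peak(policy: str, unit_params: float, unit_grads: float) -> float:
    """Peak memory at one FSDP unit boundary in the backward pass (toy model).

    unit_params / unit_grads: size of one unit's unsharded parameters / gradients.
    Assumes all FSDP units are the same size; ignores activations and optimizer state.
    """
    if policy == "BACKWARD_PRE":
        # current params + next params (already prefetched) + current grads
        return 2 * unit_params + unit_grads
    if policy == "BACKWARD_POST":
        # current params are freed before the next params are allocated
        return unit_params + unit_grads
    raise ValueError(f"policy not modeled: {policy!r}")


print(backward_peak("BACKWARD_PRE", 1.0, 1.0), backward_peak("BACKWARD_POST", 1.0, 1.0))  # 3.0 2.0
```

For a unit whose unsharded parameters and gradients each take 1 GiB, this model gives 3 GiB at the peak for BACKWARD_PRE versus 2 GiB for BACKWARD_POST.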
For more technical context: for a single process group using the NCCL backend, all collectives, even if issued from different streams, contend for the same per-device NCCL stream. As a result, the relative order in which collectives are issued determines how much they can overlap. The two backward prefetching values correspond to different issue orders.

This specifies the sharding strategy to be used for distributed training by FullyShardedDataParallel.

FULL_SHARD: Parameters, gradients, and optimizer states are sharded. For the parameters, this strategy unshards (via all-gather) before the forward, reshards after the forward, unshards before the backward computation, and reshards after the backward computation. For gradients, it synchronizes and shards them (via reduce-scatter) after the backward computation. The sharded optimizer states are updated locally per rank.

SHARD_GRAD_OP: Gradients and optimizer states are sharded during computation, and additionally, parameters are sharded outside computation. For the parameters, this strategy unshards before the forward, does not reshard them after the forward, and only reshards them after the backward computation. The sharded optimizer states are updated locally per rank. Inside no_sync(), the parameters are not resharded after the backward computation.

NO_SHARD: Parameters, gradients, and optimizer states are not sharded but instead replicated across ranks, similar to PyTorch’s DistributedDataParallel API. For gradients, this strategy synchronizes them (via all-reduce) after the backward computation. The unsharded optimizer states are updated locally per rank.

HYBRID_SHARD: Apply FULL_SHARD within a node, and replicate parameters across nodes. This reduces communication volume, because expensive all-gathers and reduce-scatters are only done within a node, which can be more performant for medium-sized models.

_HYBRID_SHARD_ZERO2: Apply SHARD_GRAD_OP within a node, and replicate parameters across nodes. This is like HYBRID_SHARD, except it may provide even higher throughput, since the unsharded parameters are not freed after the forward pass, saving the all-gathers in the pre-backward.

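A key memory difference between these strategies is what happens to a unit's parameters between its forward and its backward. The toy function below encodes only that difference:

- Its name is hypothetical.
- Strategy names are passed as strings rather than torch.distributed.fsdp.ShardingStrategy members.
- It assumes even sharding.
- It ignores gradients, optimizer state, and activations.

```python
def params_held_after_forward(strategy: str, unit_params: float,
                              world_size: int, gpus_per_node: int) -> float:
    """Per-rank memory for one FSDP unit's parameters between its forward and backward.

    Toy model of the strategy descriptions above: even sharding, and only this one
    quantity is tracked (no gradients, optimizer state, or activations).
    """
    if strategy == "FULL_SHARD":              # resharded across all ranks after forward
        return unit_params / world_size
    if strategy == "HYBRID_SHARD":             # FULL_SHARD within a node only
        return unit_params / gpus_per_node
    if strategy in ("SHARD_GRAD_OP", "_HYBRID_SHARD_ZERO2", "NO_SHARD"):
        # parameters are unsharded at this point (NO_SHARD never shards them)
        return float(unit_params)
    raise ValueError(f"unknown strategy: {strategy!r}")


for s in ("FULL_SHARD", "HYBRID_SHARD", "SHARD_GRAD_OP"):
    print(s, params_held_after_forward(s, 800, world_size=8, gpus_per_node=4))
# FULL_SHARD 100.0
# HYBRID_SHARD 200.0
# SHARD_GRAD_OP 800.0
```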
This configures FSDP-native mixed precision training.

param_dtype (Optional[torch.dtype]) – This specifies the dtype for model parameters during forward and backward and thus the dtype for forward and backward computation. Outside forward and backward, the sharded parameters are kept in full precision (e.g. for the optimizer step), and for model checkpointing, the parameters are always saved in full precision. (Default: None)

reduce_dtype (Optional[torch.dtype]) – This specifies the dtype for gradient reduction (i.e. reduce-scatter or all-reduce). If this is None but param_dtype is not None, then this takes on the param_dtype value, still running gradient reduction in low precision. This is permitted to differ from param_dtype, e.g. to force gradient reduction to run in full precision. (Default: None)

buffer_dtype (Optional[torch.dtype]) – This specifies the dtype for buffers. FSDP does not shard buffers. Rather, FSDP casts them to buffer_dtype in the first forward pass and keeps them in that dtype thereafter. For model checkpointing, the buffers are saved in full precision except for LOCAL_STATE_DICT. (Default: None)

keep_low_precision_grads (bool) – If False, then FSDP upcasts gradients to full precision after the backward pass in preparation for the optimizer step. If True, then FSDP keeps the gradients in the dtype used for gradient reduction, which can save memory if using a custom optimizer that supports running in low precision. (Default: False)

cast_forward_inputs (bool) – If True, then this FSDP module casts its forward args and kwargs to param_dtype. This is to ensure that parameter and input dtypes match for forward computation, as required by many ops. This may need to be set to True when only applying mixed precision to some but not all FSDP modules, in which case a mixed-precision FSDP submodule needs to recast its inputs. (Default: False)

cast_root_forward_inputs (bool) – If True, then the root FSDP module casts its forward args and kwargs to param_dtype, overriding the value of cast_forward_inputs. For non-root FSDP modules, this does not do anything. (Default: True)

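The fallback rules above can be traced with a small function. In real code, you would pass torch dtypes to torch.distributed.fsdp.MixedPrecision, for example `MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)`.

The sketch below is a toy resolution that uses dtype names as strings. Two things in it are assumptions: resolve_mixed_precision is a hypothetical helper, and float32 stands in for the model's full-precision dtype.

```python
from typing import Optional

FULL = "float32"  # assumption: the model's full-precision dtype


def resolve_mixed_precision(param_dtype: Optional[str] = None,
                            reduce_dtype: Optional[str] = None,
                            keep_low_precision_grads: bool = False) -> dict:
    """Toy resolution of the MixedPrecision fields described above (dtypes as names)."""
    compute = param_dtype or FULL
    # reduce_dtype=None falls back to param_dtype, so reduction also runs in low precision
    reduce = reduce_dtype or param_dtype or FULL
    # gradients are upcast for the optimizer step unless keep_low_precision_grads=True
    optimizer_grads = reduce if keep_low_precision_grads else FULL
    return {
        "compute": compute,
        "reduce": reduce,
        "optimizer_grads": optimizer_grads,
        "sharded_params": FULL,  # sharded params stay in full precision outside fwd/bwd
    }


print(resolve_mixed_precision("bfloat16"))
# {'compute': 'bfloat16', 'reduce': 'bfloat16', 'optimizer_grads': 'float32', 'sharded_params': 'float32'}
```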
_module_classes_to_ignore (Sequence[Type[nn.Module]]) – This specifies module classes to ignore for mixed precision when using an auto_wrap_policy. Modules of these classes will have FSDP applied to them separately with mixed precision disabled, meaning that the final FSDP construction would deviate from the specified policy. If auto_wrap_policy is not specified, then this does not do anything. This API is experimental and subject to change. (Default: (_BatchNorm,))

This API is experimental and subject to change.

Only floating point tensors are cast to their specified dtypes.

In summon_full_params, parameters are forced to full precision, but buffers are not.

Layer norm and batch norm accumulate in float32 even when their inputs are in a low precision like float16 or bfloat16. Disabling FSDP’s mixed precision for those norm modules only means that the affine parameters are kept in float32. However, this incurs separate all-gathers and reduce-scatters for those norm modules, which may be inefficient. If the workload permits, the user should therefore prefer to still apply mixed precision to those modules.

By default, if the user passes a model with any _BatchNorm modules and specifies an auto_wrap_policy, then the batch norm modules will have FSDP applied to them separately with mixed precision disabled. See the _module_classes_to_ignore argument.

MixedPrecision has cast_root_forward_inputs=True and cast_forward_inputs=False by default. For the root FSDP instance, its cast_root_forward_inputs takes precedence over its cast_forward_inputs. For non-root FSDP instances, their cast_root_forward_inputs values are ignored. The default setting is sufficient for the typical case where each FSDP instance has the same MixedPrecision configuration and only needs to cast inputs to the param_dtype at the beginning of the model’s forward pass.

For nested FSDP instances with different MixedPrecision configurations, we recommend setting individual cast_forward_inputs values to control whether inputs are cast before each instance’s forward. In such a case, since the casts happen before each FSDP instance’s forward, a parent FSDP instance should run its non-FSDP submodules before its FSDP submodules. This prevents the activation dtype from being changed by a different MixedPrecision configuration.

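The ordering rule in the last paragraph can be checked with a toy simulation. It assumes a bfloat16 parent FSDP instance with one float16 FSDP child, both with cast_forward_inputs=True. simulate_activation_dtypes is a hypothetical helper that tracks dtype names only and does not call torch.

```python
def simulate_activation_dtypes(order: list, parent_dtype: str = "bfloat16",
                               child_dtype: str = "float16") -> list:
    """Toy model of a parent FSDP instance (parent_dtype) and its submodules.

    order lists the parent's submodules in forward order: "plain" is a non-FSDP
    submodule computing with the parent's param_dtype, "child" is a nested FSDP
    instance with cast_forward_inputs=True and param_dtype=child_dtype.
    Returns (kind, input dtype seen, matches that submodule's params) per submodule.
    """
    activation = parent_dtype          # the parent casts its own forward inputs first
    trace = []
    for kind in order:
        if kind == "child":
            activation = child_dtype   # the child recasts; its output stays in child_dtype
            trace.append((kind, activation, True))
        else:
            trace.append((kind, activation, activation == parent_dtype))
    return trace


print(simulate_activation_dtypes(["plain", "child"]))  # every submodule sees its own dtype
print(simulate_activation_dtypes(["child", "plain"]))  # the plain submodule sees float16
```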
For example, take `model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))`. Suppose `model[1]` is wrapped in its own FSDP instance with `MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)`, and the whole model is then wrapped in FSDP with `MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True)`.

This works: model[0] runs in bfloat16 first, and model[1] then recasts its inputs to float16. On the other hand, suppose model[0] were the submodule wrapped with float16, so that the submodule with a different MixedPrecision ran its forward first. Then model[1] would incorrectly see float16 activations instead of bfloat16 ones.

This configures CPU offloading.

offload_params (bool) – This specifies whether to offload parameters to CPU when not involved in computation. If True, then this offloads gradients to CPU as well, meaning that the optimizer step runs on CPU.

StateDictConfig is the base class for all state_dict configuration classes. Users should instantiate a child class (e.g. FullStateDictConfig) in order to configure settings for the corresponding state_dict type supported by FSDP.

offload_to_cpu (bool) – If True, then FSDP offloads the state dict values to CPU, and if False, then FSDP keeps them on GPU. (Default: False)

FullStateDictConfig is a config class meant to be used with StateDictType.FULL_STATE_DICT. We recommend enabling both offload_to_cpu=True and rank0_only=True when saving full state dicts, to save GPU memory and CPU memory respectively.

This config class is meant to be used via the state_dict_type() context manager. For example, inside `with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)):`, calling `state = model.state_dict()` returns an empty dict on nonzero ranks and CPU tensors on rank 0.

rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False)

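The rank0_only behavior can be pictured as a toy gather. In this sketch, gather_full_state_dict is a hypothetical name, and flat Python lists stand in for sharded tensors.

The sketch shows only two things. Every rank contributes its shard. With rank0_only=True, nonzero ranks end up with an empty dict.

```python
def gather_full_state_dict(shards: list, rank: int, rank0_only: bool) -> dict:
    """Toy FULL_STATE_DICT save: shards[r] maps each key to rank r's flat shard (a list).

    Every rank takes part in the gather (as a collective would), but with
    rank0_only=True only rank 0 keeps the result; other ranks get {}.
    """
    full = {key: [x for shard in shards for x in shard[key]] for key in shards[0]}
    if rank0_only and rank != 0:
        return {}
    return full


shards = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}]   # two ranks, one flat parameter
print(gather_full_state_dict(shards, rank=0, rank0_only=True))  # {'w': [1.0, 2.0, 3.0, 4.0]}
print(gather_full_state_dict(shards, rank=1, rank0_only=True))  # {}
```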
ShardedStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT.

_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False)

_use_dtensor is a private field of ShardedStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor.

OptimStateDictConfig is the base class for all optim_state_dict configuration classes. Users should instantiate a child class (e.g. FullOptimStateDictConfig) in order to configure settings for the corresponding optim_state_dict type supported by FSDP.

offload_to_cpu (bool) – If True, then FSDP offloads the state dict’s tensor values to CPU, and if False, then FSDP keeps them on the original device (which is GPU unless parameter CPU offloading is enabled). (Default: True)

rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False)

ShardedOptimStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT.

_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False)

_use_dtensor is a private field of ShardedOptimStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor.

---

## Distributed Optimizers

**URL:** https://pytorch.org/docs/stable/distributed.optim.html

**Contents:**
- Distributed Optimizers

Created On: Mar 01, 2021 | Last Updated On: Jun 16, 2025

Distributed optimizer is not currently supported when using CUDA tensors.

torch.distributed.optim exposes DistributedOptimizer, which takes a list of remote parameters (RRef) and runs the optimizer locally on the workers where the parameters live. The distributed optimizer can use any local optimizer (any subclass of the torch.optim.Optimizer base class) to apply the gradients on each worker.

DistributedOptimizer takes remote references to parameters scattered across workers and applies the given optimizer locally for each parameter.

This class uses get_gradients() (torch.distributed.autograd.get_gradients()) to retrieve the gradients for specific parameters.

Concurrent calls to step(), either from the same or different clients, will be serialized on each worker, as each worker’s optimizer can only work on one set of gradients at a time. However, there is no guarantee that the full forward-backward-optimizer sequence will execute for one client at a time. This means that the gradients being applied may not correspond to the latest forward pass executed on a given worker. Also, there is no guaranteed ordering across workers.

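The per-worker serialization described above can be imitated with a lock. This toy does not use torch.distributed.optim: ToyWorker and run_clients are hypothetical names, and threads stand in for RPC clients.

The toy only shows that concurrent step() calls on one worker are applied one at a time, in no guaranteed order.

```python
import threading


class ToyWorker:
    """Hypothetical stand-in for one RPC worker owning a scalar parameter.

    Only models the rule that concurrent step() calls on a worker are serialized;
    this is not how torch.distributed.optim is implemented.
    """

    def __init__(self, param: float) -> None:
        self.param = param
        self.applied = []               # gradients, in the order they were applied
        self._lock = threading.Lock()   # one set of gradients at a time

    def step(self, grad: float, lr: float = 0.1) -> None:
        with self._lock:
            self.param -= lr * grad
            self.applied.append(grad)


def run_clients(worker: ToyWorker, grads: list) -> None:
    """Each gradient is applied by its own client thread; all threads start concurrently."""
    threads = [threading.Thread(target=worker.step, args=(g,)) for g in grads]
    for t in threads:
        t.start()
    for t in threads:
        t.join()


w = ToyWorker(1.0)
run_clients(w, [1.0, 2.0, 3.0])
print(sorted(w.applied))  # [1.0, 2.0, 3.0]; w.applied itself may be in any order
```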
DistributedOptimizer creates the local optimizer with TorchScript enabled by default, so that optimizer updates are not blocked by the Python Global Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed Model Parallel). This feature is currently enabled for most optimizers. You can also follow the recipe in PyTorch tutorials to enable TorchScript support for your own custom optimizers.

optimizer_class (optim.Optimizer) – the class of optimizer to instantiate on each worker.

params_rref (list[RRef]) – list of RRefs to local or remote parameters to optimize.

args – positional arguments to pass to the optimizer constructor on each worker.

kwargs – keyword arguments to pass to the optimizer constructor on each worker.

step(context_id): Performs a single optimization step.

This will call torch.optim.Optimizer.step() on each worker containing parameters to be optimized, and will block until all workers return. The provided context_id will be used to retrieve the corresponding context that contains the gradients that should be applied to the parameters.

context_id – the autograd context id for which we should run the optimizer step.

PostLocalSGDOptimizer wraps an arbitrary torch.optim.Optimizer and runs post-local SGD. This optimizer runs the local optimizer at every step. After the warm-up stage, it averages parameters periodically after the local optimizer is applied.

optim (Optimizer) – The local optimizer.

averager (ModelAverager) – A model averager instance to run the post-local SGD algorithm.

load_state_dict(state_dict): This is the same as torch.optim.Optimizer load_state_dict(), but also restores the model averager’s step value to the one saved in the provided state_dict.

If there is no "step" entry in state_dict, it will raise a warning and initialize the model averager’s step to 0.

state_dict(): This is the same as torch.optim.Optimizer state_dict(), but adds an extra entry to record the model averager’s step in the checkpoint, so that reloading does not trigger an unnecessary warm-up again.

step(): Performs a single optimization step (parameter update).

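The warm-up-then-average schedule can be simulated on scalar "models". This is a toy, not PostLocalSGDOptimizer, and several of its choices are assumptions:

- The function name is hypothetical.
- DDP's gradient all-reduce during warm-up is modeled as averaging gradients across ranks.
- Averaging after every `period` local steps, counted from the end of warm-up, is a simplification of what a periodic ModelAverager does.

```python
def post_local_sgd(params: list, grads_per_step: list, lr: float,
                   warmup_steps: int, period: int) -> list:
    """Toy post-local SGD on one scalar parameter per rank.

    params: initial parameter of each rank; grads_per_step[t][r]: rank r's gradient
    at step t. Assumptions for illustration: during warm-up, gradients are averaged
    across ranks (standing in for DDP's all-reduce), so replicas stay identical;
    afterwards every rank steps on its own gradient, and parameters are averaged
    after every `period` local steps.
    """
    params = list(params)
    for step, grads in enumerate(grads_per_step):
        if step < warmup_steps:
            g = sum(grads) / len(grads)                            # global (DDP-like) step
            params = [p - lr * g for p in params]
        else:
            params = [p - lr * g for p, g in zip(params, grads)]   # local step only
            if (step - warmup_steps + 1) % period == 0:
                avg = sum(params) / len(params)                    # periodic model averaging
                params = [avg] * len(params)
    return params


grads = [[1.0, 3.0]] * 3   # 2 ranks, 3 steps, fixed gradients
print(post_local_sgd([0.0, 0.0], grads, lr=0.5, warmup_steps=1, period=2))      # [-3.0, -3.0]
print(post_local_sgd([0.0, 0.0], grads[:2], lr=0.5, warmup_steps=1, period=2))  # [-1.5, -2.5]
```

Stopping after a non-averaging step, as in the second call, leaves the replicas different. They only agree again after the next averaging step.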
ZeroRedundancyOptimizer wraps an arbitrary optim.Optimizer and shards its states across ranks in the group.

The sharding is done as described by ZeRO.

The local optimizer instance in each rank is only responsible for updating approximately 1 / world_size parameters and hence only needs to keep 1 / world_size optimizer states. After parameters are updated locally, each rank broadcasts its parameters to all other peers to keep all model replicas in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.

ZeroRedundancyOptimizer uses a sorted-greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the parameter registration or usage order.

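The sorted-greedy packing can be sketched in a few lines. This assumes the common form of the heuristic: take the largest parameter first and assign each one to the rank with the smallest running total. sorted_greedy_partition is a hypothetical helper working on sizes, not ZeroRedundancyOptimizer's internal code.

```python
def sorted_greedy_partition(param_sizes: dict, world_size: int) -> list:
    """Assign whole parameters to ranks: largest first, each to the currently lightest rank.

    A sketch of the sorted-greedy idea, not ZeroRedundancyOptimizer's code;
    ties go to the lowest rank index.
    """
    assignment = [[] for _ in range(world_size)]
    load = [0] * world_size
    for name, size in sorted(param_sizes.items(), key=lambda kv: kv[1], reverse=True):
        rank = load.index(min(load))   # lightest rank so far
        assignment[rank].append(name)  # the parameter is never split across ranks
        load[rank] += size
    return assignment


print(sorted_greedy_partition({"embed": 1000, "w1": 400, "w2": 400, "b1": 10, "b2": 10}, 2))
# [['embed'], ['w1', 'w2', 'b1', 'b2']]
```

Here the embedding lands alone on rank 0 while everything else goes to rank 1. This illustrates why the partition need not follow registration order.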
params (Iterable) – an Iterable of torch.Tensors or dicts giving all parameters, which will be sharded across ranks.

optimizer_class (torch.optim.Optimizer) – the class of the local optimizer.

process_group (ProcessGroup, optional) – torch.distributed ProcessGroup (default: dist.group.WORLD initialized by torch.distributed.init_process_group()).

parameters_as_bucket_view (bool, optional) – if True, parameters are packed into buckets to speed up communication, and param.data fields point to bucket views at different offsets; if False, each individual parameter is communicated separately, and each param.data stays intact (default: False).

overlap_with_ddp (bool, optional) – if True, step() is overlapped with DistributedDataParallel's gradient synchronization. This requires two things:

- either a functional optimizer for the optimizer_class argument, or one with a functional equivalent;
- registering a DDP communication hook constructed from one of the functions in ddp_zero_hook.py.

Parameters are then packed into buckets matching those in DistributedDataParallel, meaning that the parameters_as_bucket_view argument is ignored. If False, step() runs disjointly after the backward pass (as usual). (default: False)

**defaults – any trailing arguments, which are forwarded to the local optimizer.

Currently, ZeroRedundancyOptimizer requires that all of the passed-in parameters are the same dense type.

If you pass overlap_with_ddp=True, be wary of the following. Given the way that overlapping DistributedDataParallel with ZeroRedundancyOptimizer is currently implemented, the first two or three training iterations do not perform parameter updates in the optimizer step, depending on whether static_graph=False or static_graph=True, respectively. This is because the optimizer needs information about the gradient bucketing strategy used by DistributedDataParallel. That strategy is not finalized until the second forward pass if static_graph=False, or until the third forward pass if static_graph=True. To adjust for this, one option is to prepend dummy inputs.

ZeroRedundancyOptimizer is experimental and subject to change.

add_param_group(param_group): Add a parameter group to the Optimizer's param_groups.

This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.

param_group (dict) – specifies the parameters to be optimized and group-specific optimization options.

This method handles updating the shards on all partitions but needs to be called on all ranks. Calling this on a subset of the ranks will cause the training to hang. Communication primitives are called depending on the managed parameters, and they expect all ranks to participate on the same set of parameters.

consolidate_state_dict(to=0): Consolidate a list of state_dicts (one per rank) on the target rank.

to (int) – the rank that receives the optimizer states (default: 0).

RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt.

This needs to be called on all ranks.

join_device: Return default device.

join_hook(**kwargs): Return the ZeRO join hook.

It enables training on uneven inputs by shadowing the collective communications in the optimizer step.

Gradients must be properly set before this hook is called.

kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs.

This hook does not support any keyword arguments; i.e. kwargs is unused.

join_process_group: Return process group.

load_state_dict(state_dict): Load the state pertaining to the given rank from the input state_dict, updating the local optimizer as needed.

state_dict (dict) – optimizer state; should be an object returned from a call to state_dict().

RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt.

state_dict(): Return the last global optimizer state known to this rank.

RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized (which happens once DistributedDataParallel gradient buckets have been rebuilt), or if this method is called without a preceding call to consolidate_state_dict().

step(closure=None, **kwargs): Perform a single optimizer step and sync parameters across all ranks.

closure (Callable) – a closure that re-evaluates the model and returns the loss; optional for most optimizers.

Returns: optional loss, depending on the underlying local optimizer.

Any extra parameters are passed to the base optimizer as-is.

---

## Torch Distributed Elastic

**URL:** https://pytorch.org/docs/stable/distributed.elastic.html

**Contents:**
- Torch Distributed Elastic
- Get Started
- Documentation

Created On: Jun 16, 2025 | Last Updated On: Jul 25, 2025

Makes distributed PyTorch fault-tolerant and elastic.

---

## Pipeline Parallelism

**URL:** https://pytorch.org/docs/stable/distributed.pipelining.html

**Contents:**
- Pipeline Parallelism
- Why Pipeline Parallel?
- What is torch.distributed.pipelining?
- Step 1: build PipelineStage
- Step 2: use PipelineSchedule for execution
- Options for Splitting a Model
- Option 1: splitting a model manually
- Option 2: splitting a model automatically
- Hugging Face Examples
- Technical Deep Dive

Created On: Jun 16, 2025 | Last Updated On: Aug 13, 2025

torch.distributed.pipelining is currently in alpha state and under development. API changes may be possible. It was migrated from the PiPPy project.

Pipeline Parallelism is one of the primitive forms of parallelism for deep learning. It allows the execution of a model to be partitioned such that multiple micro-batches can execute different parts of the model code concurrently. Pipeline parallelism can be an effective technique for:

- bandwidth-limited clusters
- large model inference

These scenarios share a common trait: the computation per device cannot hide the communication of conventional parallelism, for example, the weight all-gather of FSDP.

While promising for scaling, pipelining is often difficult to implement because it needs to partition the execution of a model in addition to the model weights. Partitioning the execution often requires intrusive changes to your model code. Another source of complexity is scheduling micro-batches in a distributed environment while respecting data-flow dependencies.

The pipelining package provides a toolkit that does these things automatically, which makes it easy to implement pipeline parallelism on general models.

It consists of two parts: a splitting frontend and a distributed runtime. The splitting frontend takes your model code as-is, splits it up into “model partitions”, and captures the data-flow relationship. The distributed runtime executes the pipeline stages on different devices in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation.

Overall, the pipelining package provides the following features:

- Splitting of model code based on a simple specification.
- Rich support for pipeline schedules, including GPipe, 1F1B, Interleaved 1F1B and Looped BFS, plus the infrastructure for writing customized schedules.
- First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects).
- Composability with other PyTorch parallel techniques such as data parallel (DDP, FSDP) or tensor parallel. The TorchTitan project demonstrates a “3D parallel” application on the Llama model.

Before we can use a PipelineSchedule, we need to create PipelineStage objects that wrap the part of the model running in that stage. The PipelineStage is responsible for allocating communication buffers and creating send/recv ops to communicate with its peers. It manages intermediate buffers, e.g. for the outputs of forward that have not been consumed yet, and it provides a utility for running the backward pass for the stage model.

A PipelineStage needs to know the input and output shapes for the stage model, so that it can correctly allocate communication buffers. The shapes must be static, i.e. they cannot change from step to step at runtime. A PipeliningShapeError is raised if runtime shapes do not match the expected shapes. When composing with other parallelisms or applying mixed precision, these techniques must be taken into account so that the PipelineStage knows the correct shape (and dtype) for the output of the stage module at runtime.

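Two runtime details from this section can be mimicked in plain Python: micro-batch splitting and the static shape/dtype contract. The toy makes these assumptions:

- ShapeMismatch is a local stand-in for PipeliningShapeError.
- split_microbatches and check_stage_output are hypothetical helpers.
- Requiring the batch to divide evenly is a simplification for this toy.

```python
class ShapeMismatch(Exception):
    """Local stand-in for PipeliningShapeError (not imported from torch)."""


def split_microbatches(batch: list, n_microbatches: int) -> list:
    """Split a batch (here, a list of samples) into equally sized micro-batches."""
    if len(batch) % n_microbatches:
        raise ValueError("this toy requires the batch size to divide evenly")
    size = len(batch) // n_microbatches
    return [batch[i:i + size] for i in range(0, len(batch), size)]


def check_stage_output(expected: tuple, actual: tuple) -> None:
    """Compare a runtime (shape, dtype) with what the stage sized its buffers for."""
    if expected != actual:
        raise ShapeMismatch(f"expected {expected}, got {actual}")


micro = split_microbatches(list(range(8)), 4)
print(micro)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
expected = ((len(micro[0]), 16), "bfloat16")          # per-micro-batch output: 2 x 16, bf16
check_stage_output(expected, ((2, 16), "bfloat16"))   # matches: no error
try:
    check_stage_output(expected, ((2, 16), "float32"))  # same shape, wrong dtype
except ShapeMismatch as err:
    print("rejected:", err)
```

The last check shows why mixed precision matters here. A stage whose buffers were sized for bfloat16 output rejects a float32 output of the same shape.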
Users may construct a PipelineStage instance directly by passing in an nn.Module representing the portion of the model that should run on the stage. This may require changes to the original model code. See the example in Option 1: splitting a model manually.

Alternatively, the splitting frontend can use graph partitioning to split your model into a series of nn.Module instances automatically. This technique requires that the model be traceable with torch.export. Composability of the resulting nn.Module with other parallelism techniques is experimental and may require some workarounds. This frontend may be more appealing if the user cannot easily change the model code. See Option 2: splitting a model automatically for more information.

We can now attach the PipelineStage to a pipeline schedule and run the schedule with input data. With GPipe, for example, the schedule is created as `ScheduleGPipe(stage, n_microbatches)`. The first stage calls `schedule.step(x)` with the whole batch x, which is divided into micro-batches automatically. The other stages call `schedule.step()`, and the last stage receives the output.

Note that this code needs to run on every worker, so a launcher such as torchrun is used to start multiple processes, e.g. `torchrun --nproc_per_node=2 example.py`.

To directly construct a PipelineStage, the user is responsible for providing a single nn.Module instance that owns the relevant nn.Parameters and nn.Buffers and defines a forward() method that executes the operations relevant for that stage. For example, the Transformer class defined in Torchtitan follows a pattern that makes a model easy to partition:

- Its layers live in an nn.ModuleDict keyed by layer index, so layers can be deleted without changing the names of the others.
- Its forward() skips top-level modules (such as the token embeddings, final norm, and output projection) that have been set to None.

A model defined in this manner can be easily configured per stage in three steps. First, initialize the whole model, using the meta device to avoid OOM errors. Next, delete the layers not needed for that stage. Finally, create a PipelineStage that wraps the model.

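The per-stage pruning step can be sketched without torch. The sketch rests on several assumptions:

- prune_for_stage is a hypothetical helper, and strings stand in for transformer blocks.
- Layers live in a dict keyed "0".."N-1", like an nn.ModuleDict.
- Layers are divided as evenly as possible, in order.
- Only the first stage keeps the token embeddings, and only the last stage keeps the final norm and output projection.

```python
def prune_for_stage(num_layers: int, stage_index: int, num_stages: int) -> dict:
    """Decide what one stage keeps from a Torchtitan-style model (toy, no torch needed)."""
    layers = {str(i): f"block{i}" for i in range(num_layers)}   # stand-ins for blocks
    per_stage, extra = divmod(num_layers, num_stages)
    start = stage_index * per_stage + min(stage_index, extra)
    stop = start + per_stage + (1 if stage_index < extra else 0)
    for key in list(layers):
        if not start <= int(key) < stop:
            del layers[key]                  # mirrors deleting an entry from model.layers
    return {
        "tok_embeddings": stage_index == 0,  # other stages would set it to None
        "layers": list(layers),              # surviving keys keep their original names
        "norm_and_output": stage_index == num_stages - 1,
    }


print(prune_for_stage(4, 0, 2))  # {'tok_embeddings': True, 'layers': ['0', '1'], 'norm_and_output': False}
print(prune_for_stage(4, 1, 2))  # {'tok_embeddings': False, 'layers': ['2', '3'], 'norm_and_output': True}
```

Keeping the original keys means names such as layers.2 still line up with a full-model checkpoint. This is the reason for using a dict-like container rather than a renumbered list.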
When composing with other data or model parallelism techniques, output_args may also be required if the output shape or dtype of the model chunk will be affected.

If you have a full model and do not want to spend time modifying it into a sequence of “model partitions”, the pipeline API can do the split for you.

Printing such a model typically reveals multiple levels of hierarchy, which makes it hard to split by hand.

With the pipeline API, the split is a single call, e.g. `pipe = pipeline(module=mod, mb_args=(x,), split_spec={"layers.1": SplitPoint.BEGINNING})`, where x is an example micro-batch input.

The pipeline API splits your model according to a split_spec. SplitPoint.BEGINNING adds a split point before the execution of a certain submodule in the forward function. Similarly, SplitPoint.END adds a split point after it.

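How split_spec entries turn into partitions can be pictured on a flat list of submodule names. This is a toy: ToySplitPoint mirrors SplitPoint's two members, and partition is a hypothetical helper.

The real frontend traces forward() with torch.export and cuts the resulting graph. Here, by contrast, the forward is assumed to call the named submodules in the listed order.

```python
from enum import Enum


class ToySplitPoint(Enum):
    """Local stand-in mirroring SplitPoint's two members."""
    BEGINNING = 1
    END = 2


def partition(forward_order: list, split_spec: dict) -> list:
    """Cut an ordered list of submodule names into stages according to split_spec."""
    stages = [[]]
    for name in forward_order:
        if split_spec.get(name) is ToySplitPoint.BEGINNING and stages[-1]:
            stages.append([])               # cut before this submodule runs
        stages[-1].append(name)
        if split_spec.get(name) is ToySplitPoint.END:
            stages.append([])               # cut after this submodule runs
    return [s for s in stages if s]         # drop an empty trailing stage


order = ["emb", "layers.0", "layers.1", "lm"]
print(partition(order, {"layers.1": ToySplitPoint.BEGINNING}))
# [['emb', 'layers.0'], ['layers.1', 'lm']]
```

Marking "layers.0" with END produces the same two partitions as marking "layers.1" with BEGINNING, since both place the cut between the two layers.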
|
|
3923
|
+
|
|
3924
|
+
If we print(pipe), we can see:
|
|
3925
|
+
|
|
3926
|
+
The “model partitions” are represented by submodules (submod_0, submod_1), each of which is reconstructed with original model operations, weights and hierarchies. In addition, a “root-level” forward function is reconstructed to capture the data flow between those partitions. Such data flow will be replayed by the pipeline runtime later, in a distributed fashion.
|
|
3927
|
+
|
|
3928
|
+
The Pipe object provides a method for retrieving the “model partitions”:
|
|
3929
|
+
|
|
3930
|
+
The returned stage_mod is a nn.Module, with which you can create an optimizer, save or load checkpoints, or apply other parallelisms.
|
|
3931
|
+
|
|
3932
|
+
Pipe also allows you to create a distributed stage runtime on a device given a ProcessGroup:
|
|
3933
|
+
|
|
3934
|
+
Alternatively, if you would like to build the stage runtime later after some modification to the stage_mod, you can use a functional version of the build_stage API. For example:
|
|
3935
|
+
|
|
3936
|
+
The pipeline frontend uses a tracer (torch.export) to capture your model into a single graph. If your model is not full-graph’able, you can use our manual frontend below.
|
|
3937
|
+
|
|
3938
|
+
In the PiPPy repo where this package was original created, we kept examples based on unmodified Hugging Face models. See the examples/huggingface directory.
|
|
3939
|
+
|
|
3940
|
+
First, the pipeline API turns our model into a directed acyclic graph (DAG) by tracing the model. It traces the model using torch.export – a PyTorch 2 full-graph capturing tool.
|
|
3941
|
+
|
|
3942
|
+
Then, it groups together the operations and parameters needed by a stage into a reconstructed submodule: submod_0, submod_1, …
|
|
3943
|
+
|
|
3944
|
+
Different from conventional submodule access methods like Module.children(), the pipeline API does not only cut the module structure of your model, but also the forward function of your model.
|
|
3945
|
+
|
|
3946
|
+
This is necessary because model structure like Module.children() merely captures information during Module.__init__(), and does not capture any information about Module.forward(). Said differently, Module.children() lacks information about the following aspects key to pipelininig:
|
|
3947
|
+
|
|
3948
|
+
Execution order of child modules in forward
|
|
3949
|
+
|
|
3950
|
+
Activation flows between child modules
|
|
3951
|
+
|
|
3952
|
+
Whether there are any functional operators between child modules (for example, relu or add operations will not be captured by Module.children()).
|
|
3953
|
+
|
|
3954
|
+
The pipeline API, in contrast, makes sure that the forward behavior is truly preserved. It also captures the activation flow between partitions, which lets the distributed runtime make correct send/receive calls without human intervention.
|
|
3955
|
+
|
|
3956
|
+
Another source of flexibility in the pipeline API is that split points can be at arbitrary levels of your model hierarchy. In the split partitions, the part of the original model hierarchy belonging to each partition is reconstructed at no cost to you. As a result, fully qualified names (FQNs) pointing to a submodule or parameter remain valid, and services that rely on FQNs (such as FSDP, TP, or checkpointing) can run on your partitioned modules with almost no code change.
|
|
3957
|
+
|
|
3958
|
+
You can implement your own pipeline schedule by extending one of the following two classes:
|
|
3959
|
+
|
|
3960
|
+
PipelineScheduleSingle
|
|
3961
|
+
|
|
3962
|
+
PipelineScheduleMulti
|
|
3963
|
+
|
|
3964
|
+
PipelineScheduleSingle is for schedules that assign only one stage per rank. PipelineScheduleMulti is for schedules that assign multiple stages per rank.
|
|
3965
|
+
|
|
3966
|
+
For example, ScheduleGPipe and Schedule1F1B are subclasses of PipelineScheduleSingle, whereas ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble, and ScheduleZBVZeroBubble are subclasses of PipelineScheduleMulti.
|
|
3967
|
+
|
|
3968
|
+
You can turn on additional logging using the TORCH_LOGS environment variable from torch._logging:
|
|
3969
|
+
|
|
3970
|
+
TORCH_LOGS=+pp will display logging.DEBUG messages and all levels above it.
|
|
3971
|
+
|
|
3972
|
+
TORCH_LOGS=pp will display logging.INFO messages and above.
|
|
3973
|
+
|
|
3974
|
+
TORCH_LOGS=-pp will display logging.WARNING messages and above.
|
|
3975
|
+
|
|
3976
|
+
The following set of APIs transform your model into a pipeline representation.
|
|
3977
|
+
|
|
3978
|
+
Enum representing the points at which a split can occur in the execution of a submodule. BEGINNING adds a split point before the execution of a certain submodule in the forward function; END adds a split point after the execution of a certain submodule in the forward function.
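A plain-Python sketch of the two split points, partitioning an ordered list of submodule names into stages. SplitPoint here is a local stand-in enum, and partition_by_split_spec and the module names are hypothetical; the real pipeline() splits the graph traced by torch.export rather than a list of names.

```python
from enum import Enum


class SplitPoint(Enum):
    # Local stand-in mirroring the two documented members (not torch's enum).
    BEGINNING = 1
    END = 2


def partition_by_split_spec(layer_names, split_spec):
    """Group an ordered list of submodule names into stages.

    BEGINNING starts a new stage before the named submodule;
    END closes the current stage after the named submodule.
    """
    stages, current = [], []
    for name in layer_names:
        point = split_spec.get(name)
        if point is SplitPoint.BEGINNING and current:
            stages.append(current)
            current = []
        current.append(name)
        if point is SplitPoint.END:
            stages.append(current)
            current = []
    if current:
        stages.append(current)
    return stages


layers = ["emb", "layers.0", "layers.1", "layers.2", "head"]
print(partition_by_split_spec(layers, {"layers.1": SplitPoint.BEGINNING}))
# [['emb', 'layers.0'], ['layers.1', 'layers.2', 'head']]
print(partition_by_split_spec(layers, {"layers.1": SplitPoint.END}))
# [['emb', 'layers.0', 'layers.1'], ['layers.2', 'head']]
```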
|
|
3979
|
+
|
|
3980
|
+
Split a module based on a specification.
|
|
3981
|
+
|
|
3982
|
+
See Pipe for more details.
|
|
3983
|
+
|
|
3984
|
+
module (Module) – The module to be split.
|
|
3985
|
+
|
|
3986
|
+
mb_args (tuple[Any, ...]) – Example positional inputs, in micro-batch form.
|
|
3987
|
+
|
|
3988
|
+
mb_kwargs (Optional[dict[str, Any]]) – Example keyword inputs, in micro-batch form. (default: None)
|
|
3989
|
+
|
|
3990
|
+
split_spec (Optional[dict[str, torch.distributed.pipelining._IR.SplitPoint]]) – A dictionary mapping submodule names to the split points to insert there. (default: None)
|
|
3991
|
+
|
|
3992
|
+
split_policy (Optional[Callable[[GraphModule], GraphModule]]) – The policy to use for splitting the module. (default: None)
|
|
3993
|
+
|
|
3994
|
+
A pipeline representation of class Pipe.
|
|
3995
|
+
|
|
3996
|
+
pipe_split is a special operator that marks the boundary between stages in a module; the pipeline API uses these markers to split the module into stages. It is a no-op if your annotated module is run eagerly.
|
|
3997
|
+
|
|
3998
|
+
For example, a module whose forward() calls pipe_split() once, between two layers, is split into two stages.
|
|
3999
|
+
|
|
4000
|
+
Class used to specify chunking of inputs
|
|
4001
|
+
|
|
4002
|
+
Given a sequence of args and kwargs, split them into a number of chunks according to their respective chunking specs.
|
|
4003
|
+
|
|
4004
|
+
args (tuple[Any, ...]) – Tuple of args
|
|
4005
|
+
|
|
4006
|
+
kwargs (Optional[dict[str, Any]]) – Dict of kwargs
|
|
4007
|
+
|
|
4008
|
+
chunks (int) – Number of chunks to split the args and kwargs into
|
|
4009
|
+
|
|
4010
|
+
args_chunk_spec (Optional[tuple[torch.distributed.pipelining.microbatch.TensorChunkSpec, ...]]) – chunking specs for args, in the same shape as args
|
|
4011
|
+
|
|
4012
|
+
kwargs_chunk_spec (Optional[dict[str, torch.distributed.pipelining.microbatch.TensorChunkSpec]]) – chunking specs for kwargs, in the same shape as kwargs
|
|
4013
|
+
|
|
4014
|
+
args_split: List of sharded args. kwargs_split: List of sharded kwargs.
|
|
4015
|
+
|
|
4016
|
+
Given a list of chunks, merge them into a single value according to the chunk spec.
|
|
4017
|
+
|
|
4018
|
+
chunks (list[Any]) – list of chunks
|
|
4019
|
+
|
|
4020
|
+
chunk_spec – Chunking spec for the chunks
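The split/merge round trip can be sketched in plain Python, with lists standing in for tensors chunked along dimension 0. split_batch and merge_microbatches are hypothetical helpers, not the torch.distributed.pipelining.microbatch functions, and the "sizes differ by at most one" remainder handling is an assumption that may differ from the real chunking.

```python
def _split_dim0(seq, chunks):
    """Split a list into `chunks` contiguous pieces whose sizes differ by at most 1."""
    q, r = divmod(len(seq), chunks)
    pieces, start = [], 0
    for i in range(chunks):
        end = start + q + (1 if i < r else 0)
        pieces.append(seq[start:end])
        start = end
    return pieces


def split_batch(args, kwargs, chunks):
    """Return (args_split, kwargs_split): one args tuple and one kwargs dict per microbatch."""
    split_args = [_split_dim0(a, chunks) for a in args]
    args_split = [tuple(pieces[i] for pieces in split_args) for i in range(chunks)]
    kwargs_split = [{} for _ in range(chunks)]
    for key, value in kwargs.items():
        for i, piece in enumerate(_split_dim0(value, chunks)):
            kwargs_split[i][key] = piece
    return args_split, kwargs_split


def merge_microbatches(outputs):
    """Concatenate per-microbatch outputs back along dim 0."""
    merged = []
    for out in outputs:
        merged.extend(out)
    return merged


x, mask = list(range(8)), list("abcdefgh")
args_split, kwargs_split = split_batch((x,), {"mask": mask}, chunks=4)
print(args_split[1], kwargs_split[1])  # ([2, 3],) {'mask': ['c', 'd']}
print(merge_microbatches([a[0] for a in args_split]) == x)  # True
```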
|
|
4021
|
+
|
|
4022
|
+
A class representing a pipeline stage in a pipeline parallelism setup.
|
|
4023
|
+
|
|
4024
|
+
PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from one chunk feed into inputs of the next chunk, with no skip connections.
|
|
4025
|
+
|
|
4026
|
+
PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage0 to stage1 and so forth, in linear order. To bypass shape inference, pass the input_args and output_args to each PipelineStage instance.
|
|
4027
|
+
|
|
4028
|
+
submodule (nn.Module) – The PyTorch module wrapped by this stage.
|
|
4029
|
+
|
|
4030
|
+
stage_index (int) – The ID of this stage.
|
|
4031
|
+
|
|
4032
|
+
num_stages (int) – The total number of stages.
|
|
4033
|
+
|
|
4034
|
+
device (torch.device) – The device where this stage is located.
|
|
4035
|
+
|
|
4036
|
+
input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – The input arguments for the submodule.
|
|
4037
|
+
|
|
4038
|
+
output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – The output arguments for the submodule.
|
|
4039
|
+
|
|
4040
|
+
group (dist.ProcessGroup, optional) – The process group for distributed training. If None, the default group is used.
|
|
4041
|
+
|
|
4042
|
+
dw_builder (Optional[Callable[[], Callable[..., None]]]) – If provided, dw_builder builds a new dw_runner function that runs the W action (backward with respect to weights) for F, I, W (Forward, backward Input, backward Weight) zero bubble schedules.
|
|
4043
|
+
|
|
4044
|
+
Create a pipeline stage given a stage_module to be wrapped by this stage and pipeline information.
|
|
4045
|
+
|
|
4046
|
+
stage_module (torch.nn.Module) – the module to be wrapped by this stage
|
|
4047
|
+
|
|
4048
|
+
stage_index (int) – the index of this stage in the pipeline
|
|
4049
|
+
|
|
4050
|
+
pipe_info (PipeInfo) – information about the pipeline, can be retrieved by pipe.info()
|
|
4051
|
+
|
|
4052
|
+
device (torch.device) – the device to be used by this stage
|
|
4053
|
+
|
|
4054
|
+
group (Optional[dist.ProcessGroup]) – the process group to be used by this stage
|
|
4055
|
+
|
|
4056
|
+
A pipeline stage that can run with PipelineSchedules.
|
|
4057
|
+
|
|
4058
|
+
The GPipe schedule. Will go through all the microbatches in a fill-drain manner.
|
|
4059
|
+
|
|
4060
|
+
The 1F1B schedule. Will perform one forward and one backward on the microbatches in steady state.
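To compare the two orderings, the following sketch lists the forward (F) and backward (B) actions that one stage executes and counts how many microbatch activations it holds at once. It uses the textbook 1F1B warm-up of n_stages - 1 - stage forwards; the helper names are hypothetical, and PyTorch's own scheduling loop may be structured differently.

```python
def gpipe_order(n_microbatches):
    """Fill-drain: all forwards first, then all backwards."""
    forwards = [f"F{i}" for i in range(n_microbatches)]
    backwards = [f"B{i}" for i in range(n_microbatches)]
    return forwards + backwards


def one_f_one_b_order(stage, n_stages, n_microbatches):
    """Textbook 1F1B action list for one stage (0-indexed)."""
    warmup = min(n_stages - 1 - stage, n_microbatches)
    actions = [f"F{i}" for i in range(warmup)]
    # Steady state: one forward, then one backward.
    for i in range(n_microbatches - warmup):
        actions += [f"F{warmup + i}", f"B{i}"]
    # Cool-down: drain the remaining backwards.
    actions += [f"B{i}" for i in range(n_microbatches - warmup, n_microbatches)]
    return actions


def peak_in_flight(actions):
    """Max number of microbatches whose activations are held at once."""
    live = peak = 0
    for action in actions:
        live += 1 if action[0] == "F" else -1
        peak = max(peak, live)
    return peak


print(one_f_one_b_order(stage=0, n_stages=4, n_microbatches=4))
# ['F0', 'F1', 'F2', 'F3', 'B0', 'B1', 'B2', 'B3']
print(peak_in_flight(gpipe_order(8)), peak_in_flight(one_f_one_b_order(0, 4, 8)))
# 8 4
```

With 4 stages and 8 microbatches, the first stage holds at most 4 microbatches of activations under 1F1B versus all 8 under GPipe, which is the main memory motivation for 1F1B.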
|
|
4061
|
+
|
|
4062
|
+
The Interleaved 1F1B schedule. See https://arxiv.org/pdf/2104.04473 for details. Will perform one forward and one backward on the microbatches in steady state and supports multiple stages per rank. When microbatches are ready for multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch (also called “depth first”).
|
|
4063
|
+
|
|
4064
|
+
This schedule is mostly similar to the original paper. It differs by relaxing the requirement that num_microbatch % pp_size == 0. Using the flex_pp schedule, num_rounds = max(1, n_microbatches // pp_group_size), and the schedule works as long as n_microbatches % num_rounds is 0. For example, the following configurations are supported:
|
|
4065
|
+
|
|
4066
|
+
pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
|
|
4067
|
+
|
|
4068
|
+
pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
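The round computation above is easy to check directly. flex_pp_rounds is a hypothetical helper that only mirrors the formula in the text; raising ValueError for an unsupported combination is an illustrative choice, not PyTorch's error handling.

```python
def flex_pp_rounds(n_microbatches, pp_group_size):
    """Return num_rounds for the relaxed Interleaved 1F1B rule described above."""
    num_rounds = max(1, n_microbatches // pp_group_size)
    if n_microbatches % num_rounds != 0:
        raise ValueError(
            f"{n_microbatches} microbatches cannot be split into {num_rounds} equal rounds"
        )
    return num_rounds


for n_mb in (10, 3, 9):
    try:
        print(n_mb, "->", flex_pp_rounds(n_mb, pp_group_size=4))
    except ValueError as err:
        print(n_mb, "->", err)
```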
|
|
4069
|
+
|
|
4070
|
+
Breadth-First Pipeline Parallelism. See https://arxiv.org/abs/2211.05953 for details. Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. The difference is that when microbatches are ready for multiple local stages, Looped BFS prioritizes the earlier stage, running all available microbatches at once.
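The difference between the depth-first and breadth-first priority rules can be shown with a toy selection over (local_stage, microbatch) pairs that are ready to run on one rank. These functions are hypothetical illustrations of the two rules as described, not PyTorch's scheduling code.

```python
def pick_depth_first(ready):
    """Interleaved 1F1B style: the earliest microbatch wins (ties go to the earlier stage)."""
    return [min(ready, key=lambda sm: (sm[1], sm[0]))]


def pick_breadth_first(ready):
    """Looped BFS style: the earliest local stage wins, with all of its ready microbatches."""
    stage = min(s for s, _ in ready)
    return sorted(sm for sm in ready if sm[0] == stage)


# (local_stage, microbatch) pairs whose inputs are available on this rank.
ready = [(0, 2), (0, 3), (1, 0)]
print(pick_depth_first(ready))    # [(1, 0)]
print(pick_breadth_first(ready))  # [(0, 2), (0, 3)]
```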
|
|
4071
|
+
|
|
4072
|
+
The Interleaved Zero Bubble schedule. See https://arxiv.org/pdf/2401.10241 for details. Will perform one forward and one backward on inputs for the microbatches in steady state and supports multiple stages per rank. Uses the backward for weights to fill in the pipeline bubble.
|
|
4073
|
+
|
|
4074
|
+
In particular this is implementing the ZB1P schedule in the paper.
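Zero-bubble schedules split a layer's backward pass into an input-gradient step (I) and a weight-gradient step (W). For a bias-free linear layer y = W x with upstream gradient g, the I step computes W^T g, which the previous stage needs immediately, while the W step computes g x^T, which is only needed before the optimizer step and can therefore be deferred into pipeline bubbles. A plain-Python sketch with hypothetical helper names:

```python
def matvec(W, x):
    """Forward of a bias-free linear layer: y = W x, W stored as (out, in)."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def backward_input(W, grad_y):
    """I step: dL/dx = W^T grad_y, needed right away by the previous stage."""
    return [sum(W[r][c] * grad_y[r] for r in range(len(W))) for c in range(len(W[0]))]


def backward_weight(x, grad_y):
    """W step: dL/dW = grad_y x^T, only needed before the optimizer step."""
    return [[g * v for v in x] for g in grad_y]


W, x = [[1, 2], [3, 4]], [5, 6]
grad_y = [1, -1]                           # upstream gradient from the next stage
grad_x = backward_input(W, grad_y)         # send to the previous stage immediately
deferred = [(x, grad_y)]                   # stash what the W step needs
grad_W = backward_weight(*deferred.pop())  # run later, e.g. inside a bubble
print(matvec(W, x), grad_x, grad_W)        # [17, 39] [-2, -2] [[5, 6], [-5, -6]]
```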
|
|
4075
|
+
|
|
4076
|
+
The Zero Bubble schedule (ZBV variant). See https://arxiv.org/pdf/2401.10241 Section 6 for details.
|
|
4077
|
+
|
|
4078
|
+
This schedule requires exactly two stages per rank.
|
|
4079
|
+
|
|
4080
|
+
This schedule will perform one forward and one backward on inputs for the microbatches in steady state and supports multiple stages per rank. Uses backward with respect to weights to fill in the pipeline bubble.
|
|
4081
|
+
|
|
4082
|
+
This ZB-V schedule has the “zero bubble” property only if the forward, backward-input, and backward-weight passes all take the same time. In practice this is unlikely for real models, so a greedy scheduler could be implemented instead to handle unequal or unbalanced times.
|
|
4083
|
+
|
|
4084
|
+
The DualPipeV schedule. A more efficient schedule variant based on the DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437
|
|
4085
|
+
|
|
4086
|
+
Based on the open sourced code from deepseek-ai/DualPipe
|
|
4087
|
+
|
|
4088
|
+
Base class for single-stage schedules. Implements the step method. Derived classes should implement _step_microbatches.
|
|
4089
|
+
|
|
4090
|
+
Gradients are scaled by num_microbatches depending on the scale_grads argument, defaulting to True. This setting should match the configuration of your loss_fn, which may either average losses (scale_grads=True) or sum losses (scale_grads=False).
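Why each setting pairs with a particular loss reduction can be checked on a one-parameter model y = w * x with squared error, using exact fractions and equal-sized microbatches. Summing the per-microbatch gradients stands in for gradient accumulation across microbatches; mse_grad, sse_grad, and the data are hypothetical.

```python
from fractions import Fraction


def sse_grad(w, xs, ts):
    """d/dw of sum((w*x - t)**2): the gradient of a summed loss."""
    return sum(2 * (w * x - t) * x for x, t in zip(xs, ts))


def mse_grad(w, xs, ts):
    """d/dw of mean((w*x - t)**2): the gradient of an averaged loss."""
    return Fraction(sse_grad(w, xs, ts), len(xs))


w = 1
xs, ts = [1, 2, 3, 4], [2, 2, 2, 2]
n_microbatches = 2
size = len(xs) // n_microbatches
chunks = [(xs[i:i + size], ts[i:i + size]) for i in range(0, len(xs), size)]

# Averaged loss: accumulated grads must be divided by n_microbatches (scale_grads=True).
accumulated_mean = sum(mse_grad(w, x, t) for x, t in chunks)
print(accumulated_mean, accumulated_mean / n_microbatches, mse_grad(w, xs, ts))  # 10 5 5

# Summed loss: accumulated grads already equal the full-batch grad (scale_grads=False).
accumulated_sum = sum(sse_grad(w, x, t) for x, t in chunks)
print(accumulated_sum, sse_grad(w, xs, ts))  # 20 20
```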
|
|
4091
|
+
|
|
4092
|
+
Run one iteration of the pipeline schedule with whole-batch input. Will chunk the input into microbatches automatically, and go through the microbatches according to the schedule implementation.
|
|
4093
|
+
|
|
4094
|
+
args: positional arguments to the model (as in non-pipeline case). kwargs: keyword arguments to the model (as in non-pipeline case). target: target for the loss function. losses: a list to store the losses for each microbatch.
|
|
4095
|
+
|
|
4096
|
+
Base class for multi-stage schedules. Implements the step method.
|
|
4097
|
+
|
|
4098
|
+
Gradients are scaled by num_microbatches depending on the scale_grads argument, defaulting to True. This setting should match the configuration of your loss_fn, which may either average losses (scale_grads=True) or sum losses (scale_grads=False).
|
|
4099
|
+
|
|
4100
|
+
Run one iteration of the pipeline schedule with whole-batch input. Will chunk the input into microbatches automatically, and go through the microbatches according to the schedule implementation.
|
|
4101
|
+
|
|
4102
|
+
args: positional arguments to the model (as in non-pipeline case). kwargs: keyword arguments to the model (as in non-pipeline case). target: target for the loss function. losses: a list to store the losses for each microbatch.
|
|
4103
|
+
|
|
4104
|
+
---
|
|
4105
|
+
|
|
4106
|
+
## Tensor Parallelism - torch.distributed.tensor.parallel
|
|
4107
|
+
|
|
4108
|
+
**URL:** https://pytorch.org/docs/stable/distributed.tensor.parallel.html
|
|
4109
|
+
|
|
4110
|
+
**Contents:**
|
|
4111
|
+
- Tensor Parallelism - torch.distributed.tensor.parallel
|
|
4112
|
+
|
|
4113
|
+
Created On: Jun 13, 2025 | Last Updated On: Jun 13, 2025
|
|
4114
|
+
|
|
4115
|
+
Tensor Parallelism (TP) is built on top of the PyTorch [DistributedTensor (DTensor)](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md) and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism.
|
|
4116
|
+
|
|
4117
|
+
Tensor Parallelism APIs are experimental and subject to change.
|
|
4118
|
+
|
|
4119
|
+
The entrypoint to parallelize your nn.Module using Tensor Parallelism is parallelize_module:
|
|
4120
|
+
|
|
4121
|
+
Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.
|
|
4122
|
+
|
|
4123
|
+
We parallelize the module or its sub-modules based on a parallelize_plan. The parallelize_plan contains ParallelStyle objects, which indicate how the user wants the module or sub-module to be parallelized.
|
|
4124
|
+
|
|
4125
|
+
Users can also specify a different parallel style per module fully qualified name (FQN).
|
|
4126
|
+
|
|
4127
|
+
Note that parallelize_module only accepts a 1-D DeviceMesh. If you have a 2-D or N-D DeviceMesh, slice it to a 1-D sub-DeviceMesh first and then pass that to this API (e.g. device_mesh["tp"]).
|
|
4128
|
+
|
|
4129
|
+
module (nn.Module) – Module to be parallelized.
|
|
4130
|
+
|
|
4131
|
+
device_mesh (DeviceMesh, optional) – Object which describes the mesh topology of devices for the DTensor. If not specified, the call must be under a DeviceMesh context.
|
|
4132
|
+
|
|
4133
|
+
parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]], optional) – The plan used to parallelize the module. It can be either a ParallelStyle object which contains how we prepare input/output for Tensor Parallelism or it can be a dict of module FQN and its corresponding ParallelStyle object. If not specified, the call will do nothing at the moment.
|
|
4134
|
+
|
|
4135
|
+
src_data_rank (int, optional) – the rank of the source data for the logical/global tensor, it is used by distribute_tensor() to scatter/broadcast the shards/replicas to other ranks. By default, we use group_rank=0 on each DeviceMesh dimension as the source data to preserve the single-device semantic. If passing None explicitly, parallelize_module() simply uses its local data instead of trying to preserve the single-device semantic via scatter/broadcast. Default: 0
|
|
4136
|
+
|
|
4137
|
+
The parallelized nn.Module object.
|
|
4138
|
+
|
|
4139
|
+
For complex module architectures such as Attention and MLP layers, we recommend composing different ParallelStyles together (e.g. ColwiseParallel and RowwiseParallel) and passing them as a parallelize_plan to achieve the desired sharded computation.
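The colwise-then-rowwise pattern can be checked with a plain-Python simulation of a two-layer MLP, with ranks simulated in a loop. Weights are stored as (out_features, in_features), like nn.Linear.weight, so the first layer is split by output features (column-wise) and the second by input features (row-wise). Because ReLU is elementwise, no communication is needed between the layers; a single sum of partial outputs (the all-reduce) recovers the unsharded result. The helper names are hypothetical, and this models the math only, not the DTensor implementation.

```python
def matvec(W, x):
    """y = W x with W stored as (out_features, in_features)."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def relu(v):
    return [max(0, a) for a in v]


def mlp(W1, W2, x):
    """Unsharded reference: Linear -> ReLU -> Linear, no biases."""
    return matvec(W2, relu(matvec(W1, x)))


def tp_mlp(W1, W2, x, tp_size):
    hidden = len(W1)
    if hidden % tp_size:
        raise ValueError("hidden size must be divisible by tp_size")
    shard = hidden // tp_size
    partials = []
    for rank in range(tp_size):
        lo, hi = rank * shard, (rank + 1) * shard
        # Column-wise: this rank owns output features lo:hi of layer 1.
        h_local = relu(matvec(W1[lo:hi], x))
        # Row-wise: this rank owns input features lo:hi of layer 2.
        W2_local = [row[lo:hi] for row in W2]
        partials.append(matvec(W2_local, h_local))
    # All-reduce (sum) of the partial outputs across ranks.
    return [sum(vals) for vals in zip(*partials)]


W1 = [[1, 0], [0, 1], [1, 1], [1, -1]]  # hidden=4, in=2
W2 = [[1, 2, 3, 4], [0, 1, 0, 1]]       # out=2, hidden=4
x = [2, -3]
print(mlp(W1, W2, x), tp_mlp(W1, W2, x, tp_size=2))  # [22, 5] [22, 5]
```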
|
|
4140
|
+
|
|
4141
|
+
Tensor Parallelism supports the following parallel styles:
|
|
4142
|
+
|
|
4143
|
+
ColwiseParallel: partitions a compatible nn.Module in a column-wise fashion. It currently supports nn.Linear and nn.Embedding. Users can compose it with RowwiseParallel to shard more complicated modules (e.g. MLP, Attention).
|
|
4144
|
+
|
|
4145
|
+
input_layouts (Placement, optional) – The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor. If not specified, we assume the input tensor to be replicated.
|
|
4146
|
+
|
|
4147
|
+
output_layouts (Placement, optional) – The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module has the user-desired layout. If not specified, the output tensor is sharded on the last dimension.
|
|
4148
|
+
|
|
4149
|
+
use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: True.
|
|
4150
|
+
|
|
4151
|
+
A ParallelStyle object that represents Colwise sharding of the nn.Module.
|
|
4152
|
+
|
|
4153
|
+
By default, the ColwiseParallel output is sharded on the last dimension if output_layouts is not specified. If there are operators that require a specific tensor shape (e.g. before the paired RowwiseParallel), keep in mind that when the output is sharded, those operators may need to be adjusted to the sharded size.
|
|
4154
|
+
|
|
4155
|
+
RowwiseParallel: partitions a compatible nn.Module in a row-wise fashion. It currently supports nn.Linear and nn.Embedding. Users can compose it with ColwiseParallel to shard more complicated modules (e.g. MLP, Attention).
|
|
4156
|
+
|
|
4157
|
+
input_layouts (Placement, optional) – The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.
|
|
4158
|
+
|
|
4159
|
+
output_layouts (Placement, optional) – The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module has the user-desired layout. If not specified, the output tensor is replicated.
|
|
4160
|
+
|
|
4161
|
+
use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: True.
|
|
4162
|
+
|
|
4163
|
+
A ParallelStyle object that represents Rowwise sharding of the nn.Module.
|
|
4164
|
+
|
|
4165
|
+
SequenceParallel replicates the parameters of a compatible nn.Module and runs the sharded computation with the input sharded on the sequence dimension. It currently supports nn.LayerNorm, nn.Dropout, and the Python implementation of RMSNorm.
|
|
4166
|
+
|
|
4167
|
+
This style implements the operation that is described in the paper Reducing Activation Recomputation in Large Transformer Models
|
|
4168
|
+
|
|
4169
|
+
If the input passed in to this nn.Module is a torch.Tensor, it assumes that the input is already sharded on the sequence dimension and converts the input to a DTensor sharded on the sequence dimension. If the input passed in to this nn.Module is already a DTensor but is not sharded on the sequence dimension, it would redistribute the input to be sharded on the sequence dimension.
|
|
4170
|
+
|
|
4171
|
+
The output of the nn.Module will be sharded on the sequence dimension.
|
|
4172
|
+
|
|
4173
|
+
sequence_dim (int, optional) – The sequence dimension of the input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor that is sharded on the sequence dimension, default: 1.
|
|
4174
|
+
|
|
4175
|
+
use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: False.
|
|
4176
|
+
|
|
4177
|
+
A ParallelStyle object that represents Sequence Parallel of the nn.Module.
|
|
4178
|
+
|
|
4179
|
+
The SequenceParallel style assumes ones initialization if the nn.Module has weights (e.g. nn.LayerNorm or RMSNorm, which are initialized to ones by default). If you have custom initialization for the weights of those modules, you need to broadcast the weights before or after parallelizing to ensure that they are replicated.
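Sequence parallelism is valid for LayerNorm because normalization is computed independently for each token over the hidden dimension, so each rank can normalize its own slice of the sequence. The plain-Python sketch below (hypothetical helpers, no bias term, ranks simulated in a loop) passes the same gamma to every rank, which is exactly the replication requirement in the note above; with different weights per rank, the shards would no longer match the unsharded result.

```python
import math


def layer_norm(token, gamma, eps=1e-5):
    """Normalize one token's hidden vector; gamma is the replicated weight."""
    mean = sum(token) / len(token)
    var = sum((v - mean) ** 2 for v in token) / len(token)
    return [g * (v - mean) / math.sqrt(var + eps) for v, g in zip(token, gamma)]


def sequence_parallel_layer_norm(seq, gamma, tp_size):
    """Each simulated rank normalizes only its slice of the sequence dimension."""
    shard = len(seq) // tp_size  # assumes an even split along the sequence
    outputs = []
    for rank in range(tp_size):
        local = seq[rank * shard:(rank + 1) * shard]
        outputs.extend(layer_norm(tok, gamma) for tok in local)
    # Concatenating shards here is only for comparison with the unsharded result.
    return outputs


seq = [[1.0, 2.0, 3.0], [4.0, 4.0, 5.0], [0.0, -1.0, 1.0], [2.0, 8.0, 5.0]]
gamma = [1.0, 1.0, 1.0]
full = [layer_norm(tok, gamma) for tok in seq]
print(sequence_parallel_layer_norm(seq, gamma, tp_size=2) == full)  # True
```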
|
|
4180
|
+
|
|
4181
|
+
To simply configure an nn.Module’s inputs and outputs with DTensor layouts and perform the necessary layout redistributions, without distributing the module parameters to DTensors, the following ParallelStyles can be used in the parallelize_plan when calling parallelize_module:
|
|
4182
|
+
|
|
4183
|
+
PrepareModuleInput: configures the nn.Module’s inputs, converting the input tensors of the nn.Module to DTensors at runtime according to input_layouts and performing layout redistribution according to desired_input_layouts.
|
|
4184
|
+
|
|
4185
|
+
input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, specify None as a placeholder. default: None.
|
|
4186
|
+
|
|
4187
|
+
desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. This argument needs to have the same length as input_layouts. default: None.
|
|
4188
|
+
|
|
4189
|
+
input_kwarg_layouts (Dict[str, Placement]) – The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. default: None
|
|
4190
|
+
|
|
4191
|
+
desired_input_kwarg_layouts (Dict[str, Placement]) – The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. default: None.
|
|
4192
|
+
|
|
4193
|
+
use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module inputs, default: False.
|
|
4194
|
+
|
|
4195
|
+
A ParallelStyle object that prepares the sharding layouts of the nn.Module’s inputs.
|
|
4196
|
+
|
|
4197
|
+
PrepareModuleOutput: configures the nn.Module’s outputs, converting the output tensors of the nn.Module to DTensors at runtime according to output_layouts and performing layout redistribution according to desired_output_layouts.
|
|
4198
|
+
|
|
4199
|
+
output_layouts (Union[Placement, Tuple[Placement]]) – The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to DTensors if they are torch.Tensor. If some outputs are not torch.Tensor or do not need to be converted to DTensors, specify None as a placeholder.
|
|
4200
|
+
|
|
4201
|
+
desired_output_layouts (Union[Placement, Tuple[Placement]]) – The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module have the desired DTensor layouts.
|
|
4202
|
+
|
|
4203
|
+
use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module outputs, default: True.
|
|
4204
|
+
|
|
4205
|
+
A ParallelStyle object that prepares the sharding layouts of the nn.Module’s outputs.
|
|
4206
|
+
|
|
4207
|
+
Configure the nn.Module’s inputs (and outputs) to convert the input tensors (and output tensors, respectively) of the nn.Module to DTensors at runtime according to input_layouts (and output_layouts, respectively), and perform layout redistribution according to the desired_input_layouts (and desired_output_layouts, respectively). This is a combination of PrepareModuleInput and PrepareModuleOutput.
|
|
4208
|
+
|
|
4209
|
+
input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, specify None as a placeholder. default: None.
|
|
4210
|
+
|
|
4211
|
+
desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. This argument needs to have the same length as input_layouts. default: None.
|
|
4212
|
+
|
|
4213
|
+
input_kwarg_layouts (Dict[str, Placement]) – The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. default: None
|
|
4214
|
+
|
|
4215
|
+
desired_input_kwarg_layouts (Dict[str, Placement]) – The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. default: None.
|
|
4216
|
+
|
|
4217
|
+
use_local_input (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module inputs, default: False.
|
|
4218
|
+
|
|
4219
|
+
output_layouts (Union[Placement, Tuple[Placement]]) – The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to DTensors if they are torch.Tensor. If some outputs are not torch.Tensor or do not need to be converted to DTensors, specify None as a placeholder.
|
|
4220
|
+
|
|
4221
|
+
desired_output_layouts (Union[Placement, Tuple[Placement]]) – The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module have the desired DTensor layouts.
|
|
4222
|
+
|
|
4223
|
+
use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module outputs, default: True.
|
|
4224
|
+
|
|
4225
|
+
A ParallelStyle object that prepares the sharding layouts of the nn.Module’s inputs and outputs.
|
|
4226
|
+
|
|
4227
|
+
When using Shard(dim) as the input/output layouts for the above ParallelStyles, we assume the input/output activation tensors are evenly sharded on tensor dimension dim over the DeviceMesh that TP operates on. For instance, since RowwiseParallel accepts input that is sharded on the last dimension, it assumes the input tensor has already been evenly sharded on the last dimension. For unevenly sharded activation tensors, you can pass DTensors directly to the partitioned modules and use use_local_output=False to return a DTensor after each ParallelStyle, so that the DTensor tracks the uneven sharding information.
|
|
4228
|
+
|
|
4229
|
+
For models like Transformer, we recommend using ColwiseParallel and RowwiseParallel together in the parallelize_plan to achieve the desired sharding for the entire model (e.g. Attention and MLP).
|
|
4230
|
+
|
|
4231
|
+
Parallelized cross-entropy loss computation (loss parallelism) is supported via the loss_parallel context manager:
|
|
4232
|
+
|
|
4233
|
+
A context manager that enables loss parallelism, where efficient parallelized loss computation can be performed when the input is sharded on the class dimension. Currently only the cross-entropy loss is supported.
|
|
4234
|
+
|
|
4235
|
+
Within this context manager, one can use cross_entropy() or CrossEntropyLoss as usual, with the following assumptions on the input parameters. The corresponding backward() call, if any, also needs to happen under this context manager.
|
|
4236
|
+
|
|
4237
|
+
input (DTensor) – Input logits. Assumed to be sharded on the class dimension.
|
|
4238
|
+
|
|
4239
|
+
target (Union[torch.Tensor, DTensor]) – Must be ground truth class indices (class probabilities currently not supported). Assumed to be replicated across the DeviceMesh.
|
|
4240
|
+
|
|
4241
|
+
weight (Union[torch.Tensor, DTensor], optional) – If given, assumed to be replicated across the DeviceMesh.
|
|
4242
|
+
|
|
4243
|
+
label_smoothing – Currently not supported.
|
|
4244
|
+
|
|
4245
|
+
A replicated DTensor.
|
|
4246
|
+
|
|
4247
|
+
A sharded input DTensor can be created manually to showcase this usage; in practice, it is usually the output of a TP module.
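The idea behind loss parallelism can be sketched in plain Python for a single sample: with logits sharded on the class dimension, cross-entropy only needs a max reduction, a sum reduction, and the target logit from whichever shard owns it, instead of gathering all logits onto one rank. The helpers are hypothetical, shards are plain lists, and the reductions that would be all-reduces across ranks are ordinary Python max/sum; this models the math, not the DTensor implementation.

```python
import math


def cross_entropy(logits, target):
    """Unsharded reference: logsumexp(logits) - logits[target]."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def sharded_cross_entropy(logit_shards, target):
    """Each shard holds a contiguous slice of the class dimension."""
    # 1) all-reduce(max) of the local maxima, for numerical stability.
    global_max = max(max(s) for s in logit_shards)
    # 2) all-reduce(sum) of each shard's sum(exp(logit - global_max)).
    sum_exp = sum(sum(math.exp(v - global_max) for v in s) for s in logit_shards)
    # 3) all-reduce(sum) of the target logit; only the owning shard contributes.
    offset, target_logit = 0, 0.0
    for s in logit_shards:
        if offset <= target < offset + len(s):
            target_logit += s[target - offset]
        offset += len(s)
    return global_max + math.log(sum_exp) - target_logit


logits = [2.0, -1.0, 0.5, 3.0, 1.5, 0.0]
shards = [logits[:3], logits[3:]]  # two simulated ranks
for t in range(len(logits)):
    print(t, sharded_cross_entropy(shards, t), cross_entropy(logits, t))
```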
|
|
4248
|
+
|
|
4249
|
+
---
|