d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/pipelining/factory/config.py
@@ -0,0 +1,89 @@
from typing import Annotated, Literal

from pydantic import BaseModel, Field


class PipelineScheduleInferenceConfig(BaseModel):
    """
    Configuration for inference-only pipeline execution.

    This schedule runs all forward passes sequentially without any backward passes.
    """

    schedule: Literal["inference"] = "inference"


class PipelineScheduleGPipeConfig(BaseModel):
    """
    Configuration for GPipe execution.

    This assumes a single stage per rank and processes all microbatches for the
    forward pass before switching to the backward pass.
    """

    schedule: Literal["gpipe"] = "gpipe"


class PipelineScheduleLoopedBFSConfig(BaseModel):
    """
    Configuration for Looped Breadth-First Search execution.

    Similar to GPipe, but supports multiple stages per rank (virtualization).
    It executes all available work for a specific stage before moving to the next.
    """

    schedule: Literal["looped_bfs"] = "looped_bfs"

    num_stages_per_rank: int


class PipelineSchedule1F1BConfig(BaseModel):
    """
    Configuration for Interleaved 1F1B and Interleaved Zero Bubble execution.

    Supports assigning multiple stages per rank and splitting the backward pass
    into dI and dW to reduce pipeline bubbles.
    """

    schedule: Literal["1f1b"] = "1f1b"

    num_stages_per_rank: int
    zero_bubble: bool


class PipelineScheduleZeroBubbleVConfig(BaseModel):
    """
    Configuration for Zero Bubble V (ZBV) execution.

    A specialized V-shape topology schedule that splits backward passes into
    Input and Weight gradients to maximize overlap. Requires exactly 2 stages per rank.
    """
    schedule: Literal["zero_bubble_v"] = "zero_bubble_v"


class PipelineScheduleDualPipeVConfig(BaseModel):
    """
    Configuration for DualPipeV execution.

    A bidirectional pipeline schedule for high-throughput training, utilizing
    V-shape topology and reciprocal forward/backward scheduling.
    """

    schedule: Literal["dual_pipe_v"] = "dual_pipe_v"


AnyPipelineScheduleConfig = Annotated[
    PipelineScheduleInferenceConfig |
    PipelineScheduleGPipeConfig |
    PipelineScheduleLoopedBFSConfig |
    PipelineSchedule1F1BConfig |
    PipelineScheduleZeroBubbleVConfig |
    PipelineScheduleDualPipeVConfig,
    Field(discriminator="schedule")
]
"""Union of all supported pipeline schedule configuration types.

This type alias uses a Pydantic discriminator on the ``schedule`` field to allow
polymorphic validation and serialization of specific schedule configs (e.g.
Inference, GPipe, 1F1B, ZeroBubble, etc.).
"""
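A quick sketch of how this discriminated union behaves at the boundary (illustrative only; TypeAdapter is standard Pydantic v2 API, and the config classes are the ones defined in the hunk above):

from pydantic import TypeAdapter

adapter = TypeAdapter(AnyPipelineScheduleConfig)

# The "schedule" field selects the concrete config class during validation.
cfg = adapter.validate_python(
    {"schedule": "1f1b", "num_stages_per_rank": 2, "zero_bubble": True}
)
assert isinstance(cfg, PipelineSchedule1F1BConfig)

# Serialization round-trips through the same discriminator field.
assert adapter.dump_python(cfg)["schedule"] == "1f1b"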
d9d/pipelining/factory/factory.py
@@ -0,0 +1,114 @@
import dataclasses
from collections.abc import Callable

import torch
from torch import nn

from ...core.dist_context import REGULAR_DOMAIN, DistributedContext
from ..api import PipelineSchedule, PipelineStageInfo
from ..infra.schedule.component.program import (
    build_stage_to_host_rank_topology,
    invert_stage_to_host_rank_topology,
)
from ..infra.schedule.component.runtime import PipelineScheduleExecutor
from ..infra.stage import PipelineStage
from .config import (
    AnyPipelineScheduleConfig,
)
from .registry import PIPELINE_PROGRAM_REGISTRY


@dataclasses.dataclass(kw_only=True)
class PipelineScheduleInfo:
    """Contains the built pipeline schedule and rank-specific metadata."""

    schedule: PipelineSchedule
    has_first_stage: bool
    has_last_stage: bool


def build_schedule(
    dist_context: DistributedContext,
    n_microbatches: int,
    schedule_config: AnyPipelineScheduleConfig,
    model_provider: Callable[[PipelineStageInfo], nn.Module],
    loss_fn: Callable[[dict[str, torch.Tensor], int], torch.Tensor] | None,
) -> tuple[PipelineScheduleInfo, list[nn.Module]]:
    """
    Constructs the pipeline schedule and instantiates model stages.

    This function coordinates the creation of the distributed pipeline. It:
    1. Selects the appropriate `PipelineProgramBuilder` based on the config.
    2. Calculates the global stage topology mapping stages to ranks.
    3. Instantiates the local model stages for the current rank using `model_provider`.
    4. Wraps models in `PipelineStage` containers.
    5. Generates the execution program (action list).
    6. Builds the runtime executor.

    Args:
        dist_context: The distributed context.
        n_microbatches: Number of microbatches per global step.
        schedule_config: Configuration object determining the schedule strategy.
        model_provider: A factory function that accepts stage info and returns an `nn.Module`
            for that specific stage.
        loss_fn: Optional loss function. Required if training (backward pass needed).

    Returns:
        A tuple containing:
        1. `PipelineScheduleInfo`: The executable schedule and metadata.
        2. `list[nn.Module]`: The local PyTorch modules created for this rank.
    """

    program_builder = PIPELINE_PROGRAM_REGISTRY.program_for(schedule_config)
    mesh = dist_context.mesh_for(REGULAR_DOMAIN)["pp"]

    num_stages = program_builder.num_stages_per_rank * mesh.size()

    stage_to_host = build_stage_to_host_rank_topology(
        num_stages=num_stages,
        pp_size=mesh.size(),
        style=program_builder.topology_style
    )
    host_to_stage = invert_stage_to_host_rank_topology(stage_to_host)
    this_rank_stages = host_to_stage[mesh.get_local_rank()]

    stages = []
    modules = []
    has_first_stage = False
    has_last_stage = False

    for stage_idx in this_rank_stages:
        stage_info = PipelineStageInfo(
            num_stages=num_stages,
            current_stage=stage_idx
        )

        if stage_info.is_current_stage_first:
            has_first_stage = True
        if stage_info.is_current_stage_last:
            has_last_stage = True

        model = model_provider(stage_info)
        modules.append(model)
        stage = PipelineStage(
            info=stage_info,
            module=model,
            group=mesh.get_group(),
            stage_to_host_topology=stage_to_host
        )
        stages.append(stage)

    program = program_builder.compose(num_microbatches=n_microbatches, pp_size=mesh.size())
    schedule = PipelineScheduleExecutor(
        dist_context=dist_context,
        stages=stages,
        num_microbatches=n_microbatches,
        loss_fn=loss_fn,
        program=program
    )

    return PipelineScheduleInfo(
        schedule=schedule,
        has_first_stage=has_first_stage,
        has_last_stage=has_last_stage
    ), modules
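A hedged usage sketch of `build_schedule`. The ready-made `dist_context`, the toy stage module, and the "loss" output key below are hypothetical placeholders; only `build_schedule`, `PipelineStageInfo`, and the config type come from the code above:

import torch
from torch import nn

def model_provider(info):  # info: PipelineStageInfo
    # Hypothetical stage module; a real provider would build only the layers
    # owned by stage `info.current_stage` out of `info.num_stages`.
    return nn.Linear(1024, 1024)

def loss_fn(outputs: dict[str, torch.Tensor], microbatch_idx: int) -> torch.Tensor:
    # Hypothetical: assumes the last stage emits a dict with a "loss" entry.
    return outputs["loss"]

info, local_modules = build_schedule(
    dist_context=dist_context,  # assumed: an already-initialized DistributedContext
    n_microbatches=8,
    schedule_config=PipelineSchedule1F1BConfig(num_stages_per_rank=2, zero_bubble=True),
    model_provider=model_provider,
    loss_fn=loss_fn,  # required here because 1F1B runs backward passes
)
# info.schedule is the executable schedule; info.has_first_stage and
# info.has_last_stage tell this rank whether it feeds inputs or computes loss.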
d9d/pipelining/factory/registry.py
@@ -0,0 +1,82 @@
from collections.abc import Callable
from typing import TypeVar, cast

from d9d.pipelining.factory import (
    AnyPipelineScheduleConfig,
    PipelineSchedule1F1BConfig,
    PipelineScheduleDualPipeVConfig,
    PipelineScheduleGPipeConfig,
    PipelineScheduleInferenceConfig,
    PipelineScheduleLoopedBFSConfig,
    PipelineScheduleZeroBubbleVConfig,
)
from d9d.pipelining.infra.schedule.component.program import PipelineProgramBuilder
from d9d.pipelining.infra.schedule.program import (
    DualPipeVPipelineProgramBuilder,
    Interleaved1F1BPipelineProgramBuilder,
    LoopedBFSPipelineProgramBuilder,
    ZeroBubbleVPipelineProgramBuilder,
)

TConfig = TypeVar("TConfig", bound=AnyPipelineScheduleConfig)

TRegistryDict = dict[
    type[AnyPipelineScheduleConfig],
    Callable[[AnyPipelineScheduleConfig], PipelineProgramBuilder]
]

TBoundRegistryFn = Callable[[TConfig], PipelineProgramBuilder]


class PipelineProgramRegistry:
    def __init__(self) -> None:
        self._registry: TRegistryDict = {}

    def register_program(
        self, config_cls: type[TConfig]
    ) -> Callable[[TBoundRegistryFn], TBoundRegistryFn]:
        def decorator(func: TBoundRegistryFn) -> TBoundRegistryFn:
            config_cls_any = cast(type[AnyPipelineScheduleConfig], config_cls)
            self._registry[config_cls_any] = func
            return func

        return decorator

    def program_for(self, config: AnyPipelineScheduleConfig) -> PipelineProgramBuilder:
        program_fn = self._registry[type(config)]
        program = program_fn(config)
        return program


PIPELINE_PROGRAM_REGISTRY = PipelineProgramRegistry()


@PIPELINE_PROGRAM_REGISTRY.register_program(PipelineScheduleGPipeConfig)
def _build_gpipe(_: PipelineScheduleGPipeConfig) -> PipelineProgramBuilder:
    return LoopedBFSPipelineProgramBuilder(num_stages_per_rank=1, inference_mode=False)


@PIPELINE_PROGRAM_REGISTRY.register_program(PipelineScheduleInferenceConfig)
def _build_inference(_: PipelineScheduleInferenceConfig) -> PipelineProgramBuilder:
    return LoopedBFSPipelineProgramBuilder(num_stages_per_rank=1, inference_mode=True)


@PIPELINE_PROGRAM_REGISTRY.register_program(PipelineScheduleLoopedBFSConfig)
def _build_looped_bfs(cfg: PipelineScheduleLoopedBFSConfig) -> PipelineProgramBuilder:
    return LoopedBFSPipelineProgramBuilder(num_stages_per_rank=cfg.num_stages_per_rank, inference_mode=False)


@PIPELINE_PROGRAM_REGISTRY.register_program(PipelineSchedule1F1BConfig)
def _build_1f1b(cfg: PipelineSchedule1F1BConfig) -> PipelineProgramBuilder:
    return Interleaved1F1BPipelineProgramBuilder(num_stages_per_rank=cfg.num_stages_per_rank,
                                                 enable_zero_bubble=cfg.zero_bubble)


@PIPELINE_PROGRAM_REGISTRY.register_program(PipelineScheduleDualPipeVConfig)
def _build_dual_pipe_v(_: PipelineScheduleDualPipeVConfig) -> PipelineProgramBuilder:
    return DualPipeVPipelineProgramBuilder()


@PIPELINE_PROGRAM_REGISTRY.register_program(PipelineScheduleZeroBubbleVConfig)
def _build_zero_bubble_v(_: PipelineScheduleZeroBubbleVConfig) -> PipelineProgramBuilder:
    return ZeroBubbleVPipelineProgramBuilder()
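Dispatch in `program_for` is an exact-type dict lookup, so a config resolves to the builder registered for its class (and an unregistered config class raises KeyError). A small sketch using the names defined above:

cfg = PipelineScheduleLoopedBFSConfig(num_stages_per_rank=2)

builder = PIPELINE_PROGRAM_REGISTRY.program_for(cfg)
# _build_looped_bfs was invoked, so `builder` is a
# LoopedBFSPipelineProgramBuilder configured from the config.
assert builder.num_stages_per_rank == 2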
File without changes: d9d/pipelining/infra/__init__.py
File without changes: d9d/pipelining/infra/schedule/__init__.py
File without changes: d9d/pipelining/infra/schedule/component/__init__.py
d9d/pipelining/infra/schedule/component/program/__init__.py
@@ -0,0 +1,22 @@
"""
Pipeline Schedule Building Components.

This package provides the core building blocks and compiler passes used to generate
execution schedules for distributed pipelines.
"""

from .base import PipelineProgramBuilder
from .communications import add_communication_ops
from .topology import (
    ScheduleStyle,
    build_stage_to_host_rank_topology,
    invert_stage_to_host_rank_topology,
)

__all__ = [
    "PipelineProgramBuilder",
    "ScheduleStyle",
    "add_communication_ops",
    "build_stage_to_host_rank_topology",
    "invert_stage_to_host_rank_topology"
]
d9d/pipelining/infra/schedule/component/program/base.py
@@ -0,0 +1,35 @@
import abc

from ..program.topology import ScheduleStyle
from ..runtime import ActionBase


class PipelineProgramBuilder(abc.ABC):
    """Abstract interface for building pipeline execution schedules."""

    @abc.abstractmethod
    def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
        """
        Generates the execution program for all ranks in the pipeline.

        Args:
            num_microbatches: Number of microbatches per step.
            pp_size: Number of pipeline parallel ranks.

        Returns:
            A dictionary mapping rank indices to their list of sequential actions.
        """
        ...

    @property
    @abc.abstractmethod
    def num_stages_per_rank(self) -> int:
        """Returns the number of model stages designated for each rank."""
        ...

    @property
    @abc.abstractmethod
    def topology_style(self) -> ScheduleStyle:
        """Returns the topology style strategy used to assign stages to ranks."""
        ...
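For orientation, a minimal concrete builder satisfying this interface. It is illustrative only, not a schedule shipped in the package; the positional ForwardComputeAction(stage_idx, microbatch_idx) form is assumed from how actions are constructed in communications.py below:

from d9d.pipelining.infra.schedule.component.runtime import ActionBase, ForwardComputeAction


class TrivialProgramBuilder(PipelineProgramBuilder):
    """Illustrative: one stage per rank, forward-only, no interleaving."""

    def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
        # Rank r owns stage r and runs every microbatch through it in order.
        return {
            rank: [ForwardComputeAction(rank, mb) for mb in range(num_microbatches)]
            for rank in range(pp_size)
        }

    @property
    def num_stages_per_rank(self) -> int:
        return 1

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.loop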
d9d/pipelining/infra/schedule/component/program/communications.py
@@ -0,0 +1,203 @@
import copy
import dataclasses

from ..runtime.action import (
    ActionBase,
    BackwardFullInputComputeAction,
    BackwardReceiveAction,
    BackwardSendAction,
    ComposeAction,
    ForwardComputeAction,
    ForwardReceiveAction,
    ForwardSendAction,
)


def _get_sub_actions(action: ActionBase) -> tuple[ActionBase, ...]:
    if isinstance(action, ComposeAction):
        return action.actions
    return (action,)


def _check_action_communication_dependencies_fulfilled(
    action: ActionBase,
    rank_events: set[ActionBase],
    num_stages: int
) -> bool:
    match action:
        case ForwardComputeAction():
            if action.stage_idx == 0:
                return True
            if ForwardReceiveAction(action.stage_idx, action.microbatch_idx) in rank_events:
                return True
            if ForwardComputeAction(action.stage_idx - 1, action.microbatch_idx) in rank_events:
                return True
            return False
        case BackwardFullInputComputeAction():
            if action.stage_idx == num_stages - 1:
                return True
            if BackwardReceiveAction(action.stage_idx, action.microbatch_idx) in rank_events:
                return True

            next_full = BackwardFullInputComputeAction(
                action.stage_idx + 1,
                action.microbatch_idx,
                full_backward=True
            )
            next_inp = BackwardFullInputComputeAction(
                action.stage_idx + 1,
                action.microbatch_idx,
                full_backward=False
            )

            if next_full in rank_events or next_inp in rank_events:
                return True
            return False
        case _:
            return True


def check_action_communication_dependencies_fulfilled(
    action: ActionBase,
    rank_events: set[ActionBase],
    num_stages: int
) -> bool:
    """
    Checks if data dependencies (Receive or Local Compute) are met for an action.

    This function determines if a compute action is allowed to run based on
    whether its inputs are available in `rank_events`. Inputs are available
    if they were either computed locally by a previous stage or received
    from a remote rank.

    Args:
        action: The action to check.
        rank_events: A set of actions already completed on this rank.
        num_stages: Total number of stages in the pipeline.

    Returns:
        True if all dependencies are satisfied, False otherwise.
    """

    return all(
        _check_action_communication_dependencies_fulfilled(sub, rank_events, num_stages)
        for sub in _get_sub_actions(action)
    )


@dataclasses.dataclass(kw_only=True)
class _CommunicationPackage:
    send: ActionBase
    recv: ActionBase
    sends_to_rank: int


def _create_communications_for_action(
    action: ActionBase,
    num_stages: int,
    stage_to_rank: dict[int, int],
) -> _CommunicationPackage | None:
    match action:
        case ForwardComputeAction():
            if action.stage_idx == num_stages - 1:
                return None

            curr_rank, next_rank = stage_to_rank[action.stage_idx], stage_to_rank[action.stage_idx + 1]
            if curr_rank == next_rank:
                return None

            return _CommunicationPackage(
                send=ForwardSendAction(action.stage_idx, action.microbatch_idx),
                recv=ForwardReceiveAction(action.stage_idx + 1, action.microbatch_idx),
                sends_to_rank=next_rank
            )
        case BackwardFullInputComputeAction():
            if action.stage_idx == 0:
                return None

            curr_rank, prev_rank = stage_to_rank[action.stage_idx], stage_to_rank[action.stage_idx - 1]
            if curr_rank == prev_rank:
                return None

            return _CommunicationPackage(
                send=BackwardSendAction(action.stage_idx, action.microbatch_idx),
                recv=BackwardReceiveAction(action.stage_idx - 1, action.microbatch_idx),
                sends_to_rank=prev_rank
            )
        case _:
            return None


def add_communication_ops(
    compute_actions: dict[int, list[ActionBase]],
    stage_to_rank: dict[int, int],
    num_stages: int,
) -> dict[int, list[ActionBase]]:
    """
    Injects communication actions into a computation-only schedule.

    This function iterates through the provided compute schedule and simulates execution.
    When a compute action produces a result needed by a different rank, it injects
    Send/Receive pairs. It also reorders actions to ensure that Receive
    operations occur before the Computes that depend on them, preventing deadlocks.

    Args:
        compute_actions: Initial schedule containing only compute operations.
        stage_to_rank: Mapping from stage index to rank index.
        num_stages: Total number of pipeline stages.

    Returns:
        A new schedule dictionary including both compute and communication actions.

    Raises:
        RuntimeError: If the schedule simulation enters a deadlock state.
    """

    compute_actions = copy.deepcopy(compute_actions)

    full_actions: dict[int, list[ActionBase]] = {rank: [] for rank in compute_actions}
    completed_events: dict[int, set[ActionBase]] = {rank: set() for rank in compute_actions}

    while compute_actions:
        progress = False

        for rank in sorted(compute_actions.keys()):
            if not compute_actions[rank]:
                del compute_actions[rank]
                continue

            current_action = compute_actions[rank][0]
            sub_actions = _get_sub_actions(current_action)

            # Check readiness
            if not check_action_communication_dependencies_fulfilled(
                current_action, completed_events[rank], num_stages
            ):
                continue

            # Execute
            full_actions[rank].append(current_action)
            compute_actions[rank].pop(0)
            progress = True

            for sub_action in sub_actions:
                completed_events[rank].add(sub_action)

                comm_pkg = _create_communications_for_action(
                    sub_action,
                    num_stages=num_stages,
                    stage_to_rank=stage_to_rank
                )
                if comm_pkg:
                    # Add Send locally
                    full_actions[rank].append(comm_pkg.send)
                    completed_events[rank].add(comm_pkg.send)

                    # Add Recv remotely and unblock target
                    full_actions[comm_pkg.sends_to_rank].append(comm_pkg.recv)
                    completed_events[comm_pkg.sends_to_rank].add(comm_pkg.recv)

        if not progress and compute_actions:
            raise RuntimeError("Deadlock in schedule simulation")

    return full_actions
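Traced by hand on the smallest interesting case (two stages on two ranks, one microbatch; this assumes the actions are value-comparable dataclasses, which the set-membership checks above already require):

compute = {
    0: [ForwardComputeAction(0, 0)],  # rank 0 runs stage 0
    1: [ForwardComputeAction(1, 0)],  # rank 1 runs stage 1
}
full = add_communication_ops(compute, stage_to_rank={0: 0, 1: 1}, num_stages=2)

# Rank 1's compute is initially blocked; once the simulation executes rank 0's
# compute it injects the Send/Receive pair, yielding:
# full[0] == [ForwardComputeAction(0, 0), ForwardSendAction(0, 0)]
# full[1] == [ForwardReceiveAction(1, 0), ForwardComputeAction(1, 0)]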
d9d/pipelining/infra/schedule/component/program/topology.py
@@ -0,0 +1,78 @@
from collections import defaultdict
from enum import StrEnum


class ScheduleStyle(StrEnum):
    """
    Defines the strategy for mapping logical stages to physical ranks.

    Attributes:
        loop: Assigns stages in a round-robin circular fashion (mod pp_size).
        v: Assigns stages in a zig-zag V-shape pattern. Useful for interleaved 1F1B schedules.
    """

    loop = "loop"
    v = "v"


def build_stage_to_host_rank_topology(
    pp_size: int, num_stages: int, style: ScheduleStyle
) -> dict[int, int]:
    """
    Constructs the mapping from stage index to rank index.

    Args:
        pp_size: Number of pipeline parallel ranks.
        num_stages: Total number of model stages.
        style: The topology style to use for assignment.

    Returns:
        A dictionary mapping stage IDs to rank IDs.

    Raises:
        ValueError: If the style is unknown or if V-style parameters are invalid
            (num_stages must be divisible by pp_size).
    """

    match style:
        case ScheduleStyle.loop:
            return {stage_index: stage_index % pp_size for stage_index in range(num_stages)}
        case ScheduleStyle.v:
            if num_stages % pp_size != 0:
                raise ValueError(
                    f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules"
                )

            result = {}
            rank_index = 0
            for stage_index in range(num_stages):
                result[stage_index] = rank_index
                if (stage_index + 1) % pp_size == 0:
                    continue
                if (stage_index // pp_size) % 2 == 0:
                    rank_index += 1
                else:
                    rank_index -= 1
            return result
        case _:
            raise ValueError()


def invert_stage_to_host_rank_topology(
    stage_to_host: dict[int, int]
) -> dict[int, list[int]]:
    """
    Inverts the topology mapping to list execution stages per rank.

    Args:
        stage_to_host: Mapping from stage index to rank index.

    Returns:
        A dictionary where keys are rank IDs and values are lists of stage IDs
        managed by that rank.
    """

    host_to_stage = defaultdict(list)
    for stage_idx, host in stage_to_host.items():
        host_to_stage[host].append(stage_idx)
    return dict(host_to_stage)
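Concretely, with pp_size=4 and num_stages=8 the two styles give (outputs traced from the code above):

build_stage_to_host_rank_topology(pp_size=4, num_stages=8, style=ScheduleStyle.loop)
# -> {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}

build_stage_to_host_rank_topology(pp_size=4, num_stages=8, style=ScheduleStyle.v)
# -> {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1, 7: 0}

invert_stage_to_host_rank_topology(
    build_stage_to_host_rank_topology(pp_size=4, num_stages=8, style=ScheduleStyle.v)
)
# -> {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}

Note how the V style places the last stage back on rank 0 and keeps stages 3 and 4 adjacent on rank 3: the two local stages at the valley of the V can hand activations to each other without any cross-rank communication, which is what the V-topology schedules (ZBV, DualPipeV) exploit.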
d9d/pipelining/infra/schedule/component/runtime/__init__.py
@@ -0,0 +1,29 @@
"""
Pipelining Runtime Package.
"""

from .action import (
    ActionBase,
    BackwardFullInputComputeAction,
    BackwardReceiveAction,
    BackwardSendAction,
    BackwardWeightComputeAction,
    ComposeAction,
    ForwardComputeAction,
    ForwardReceiveAction,
    ForwardSendAction,
)
from .executor import PipelineScheduleExecutor

__all__ = [
    "ActionBase",
    "BackwardFullInputComputeAction",
    "BackwardReceiveAction",
    "BackwardSendAction",
    "BackwardWeightComputeAction",
    "ComposeAction",
    "ForwardComputeAction",
    "ForwardReceiveAction",
    "ForwardSendAction",
    "PipelineScheduleExecutor",
]