d9d 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/kernel/swiglu/op.py +17 -4
- d9d/loop/component/__init__.py +5 -2
- d9d/loop/component/gradient_manager.py +5 -1
- d9d/loop/component/model_stage_factory.py +14 -54
- d9d/loop/component/{loss_computer.py → pipeline_result_processing.py} +73 -10
- d9d/loop/component/{train_task_operator.py → task_operator.py} +97 -57
- d9d/loop/config/config.py +1 -1
- d9d/loop/control/task.py +2 -2
- d9d/loop/run/__init__.py +3 -0
- d9d/loop/run/inference.py +256 -0
- d9d/loop/run/train.py +2 -4
- d9d/loop/state.py +4 -1
- d9d/pipelining/api/__init__.py +3 -0
- d9d/pipelining/api/types.py +28 -0
- d9d/pipelining/factory/factory.py +68 -29
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +3 -1
- d9d/pipelining/infra/schedule/component/runtime/action.py +7 -10
- d9d/pipelining/infra/schedule/component/runtime/{loss.py → callback.py} +33 -10
- d9d/pipelining/infra/schedule/component/runtime/executor.py +10 -8
- d9d/pipelining/infra/schedule/component/runtime/offline.py +70 -0
- {d9d-0.1.1.dist-info → d9d-0.2.1.dist-info}/METADATA +2 -1
- {d9d-0.1.1.dist-info → d9d-0.2.1.dist-info}/RECORD +24 -20
- d9d-0.2.1.dist-info/licenses/LICENSE +201 -0
- {d9d-0.1.1.dist-info → d9d-0.2.1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
|
|
2
|
+
import torch
|
|
3
|
+
from tqdm import tqdm
|
|
4
|
+
|
|
5
|
+
from d9d.core.dist_context import DeviceMeshParameters
|
|
6
|
+
from d9d.internals.determinism import set_seeds
|
|
7
|
+
from d9d.internals.pipeline_state import PipelineStateHandler
|
|
8
|
+
from d9d.loop.component import (
|
|
9
|
+
BatchMaths,
|
|
10
|
+
DataLoaderFactory,
|
|
11
|
+
InferenceProcessor,
|
|
12
|
+
InferenceTaskOperator,
|
|
13
|
+
JobProfiler,
|
|
14
|
+
ManualGarbageCollector,
|
|
15
|
+
ModelStageFactory,
|
|
16
|
+
StateCheckpointer,
|
|
17
|
+
Stepper,
|
|
18
|
+
TimeoutManager,
|
|
19
|
+
)
|
|
20
|
+
from d9d.loop.config import InferenceConfig, PipeliningConfig
|
|
21
|
+
from d9d.loop.control import (
|
|
22
|
+
DatasetProvider,
|
|
23
|
+
FinalizeContext,
|
|
24
|
+
InferenceTaskProvider,
|
|
25
|
+
InferenceTaskProviderContext,
|
|
26
|
+
ModelProvider,
|
|
27
|
+
)
|
|
28
|
+
from d9d.loop.state import InferenceJobState
|
|
29
|
+
from d9d.pipelining.factory import PipelineScheduleInferenceConfig
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class InferenceConfigurator:
    """
    Assembles every component needed to run a distributed inference job.

    The configurator couples the device-mesh description, the inference
    configuration and the user-supplied providers (task, model, data), and
    turns them into an ``InferenceJobState`` that is wrapped by an
    ``Inference`` engine.
    """

    def __init__(
        self,
        mesh: DeviceMeshParameters,
        parameters: InferenceConfig,
        task_provider: InferenceTaskProvider,
        model_provider: ModelProvider,
        data_provider: DatasetProvider
    ):
        """
        Stores the inputs required to build the inference state later on.

        Args:
            mesh: Description of the distributed device mesh to build.
            parameters: Inference configuration (batching, data loading, gc,
                checkpointing, profiling, timeouts, determinism, model stages).
            task_provider: Callable producing the inference task from an
                ``InferenceTaskProviderContext``.
            model_provider: Provider handed to ``ModelStageFactory`` to create model stages.
            data_provider: Provider handed to ``DataLoaderFactory`` to create the dataset.
        """

        self._mesh = mesh
        self._parameters = parameters
        self._task_provider = task_provider
        self._model_provider = model_provider
        self._data_provider = data_provider

    def _build_new_state(self) -> InferenceJobState:
        """
        Builds every inference component and bundles them into an ``InferenceJobState``.

        Returns:
            The fully populated job state.
        """

        cfg = self._parameters
        ctx = self._mesh.build()

        # Inference always runs the pipeline with the inference (forward-only) schedule.
        pipe_cfg = PipeliningConfig(schedule=PipelineScheduleInferenceConfig())

        # Seeds are set before the task, data and model providers are invoked.
        set_seeds(ctx, seed=cfg.determinism.base_seed)

        inference_task = self._task_provider(InferenceTaskProviderContext(dist_context=ctx))

        maths = BatchMaths(
            dist_context=ctx,
            config_batching=cfg.batching,
            config_pipelining=pipe_cfg
        )

        loader_factory = DataLoaderFactory(
            dist_context=ctx,
            provider=self._data_provider,
            config_data_loading=cfg.data_loading,
            batch_maths=maths
        )
        loader = loader_factory.build_dataloader_for_infer_job()

        # One step per batch group yielded by the inference data loader.
        step_counter = Stepper(initial_step=1, total_steps=len(loader))

        # Pipeline state handler with an empty sharding spec, split into
        # `num_microbatches_pipelining` shards.
        state_handler = PipelineStateHandler(
            sharding_spec={},
            num_shards=maths.num_microbatches_pipelining
        )

        # Receives final-stage outputs through the pipeline callback.
        result_processor = InferenceProcessor(state=state_handler, task=inference_task)

        pipeline_schedule, tracked = ModelStageFactory(
            model_provider=self._model_provider,
            dist_context=ctx,
            config_model=cfg.model_stage_factory,
            config_pipelining=pipe_cfg,
            batch_maths=maths,
            pipeline_callback=result_processor
        ).build_pipeline_and_modules()

        operator = InferenceTaskOperator(
            dist_context=ctx,
            task=inference_task,
            pipeline=pipeline_schedule,
            pipeline_state=state_handler
        )

        collector = ManualGarbageCollector(dist_ctx=ctx, config=cfg.gc, step=step_counter)

        # Inference runs are anonymous: no run name is attached to checkpoints.
        state_checkpointer = StateCheckpointer(
            dist_context=ctx,
            stepper=step_counter,
            config=cfg.checkpointing,
            gc=collector,
            run_name=None
        )

        # The profiler and timeout manager are created (in this order) while
        # the keyword arguments below are evaluated.
        return InferenceJobState(
            dist_context=ctx,
            data_loader=loader,
            stepper=step_counter,
            tracked_modules=tracked,
            garbage_collector=collector,
            batch_maths=maths,
            checkpointer=state_checkpointer,
            task=inference_task,
            profiler=JobProfiler(dist_context=ctx, stepper=step_counter, config=cfg.profiling),
            timeout_manager=TimeoutManager(dist_context=ctx, config=cfg.timeout),
            task_operator=operator
        )

    def configure(self) -> "Inference":
        """
        Builds the full inference state and wraps it in an ``Inference`` engine.

        Building creates the distributed context, sets seeds, and instantiates
        the task, data loader, model stages, pipeline schedule and auxiliary
        components (garbage collector, checkpointer, profiler, timeout manager).

        Returns:
            Inference: Engine bound to the freshly built job state.
        """

        return Inference(self._build_new_state())
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class Inference:
    """
    Execution engine that drives a distributed inference job to completion.

    The engine switches the tracked modules to evaluation mode, synchronizes
    the world, resumes from the latest checkpoint, and then walks the data
    loader under ``torch.inference_mode``. Forward passes are interleaved with
    periodic garbage collection, checkpointing and profiler steps.
    """

    def __init__(self, state: InferenceJobState):
        """
        Binds the engine to an already-built job state.

        Args:
            state: Job state holding every initialized inference component.
        """

        self._state = state

    def _enable_eval_mode(self):
        """Switches every tracked module into evaluation mode."""
        for tracked_module in self._state.tracked_modules.modules:
            tracked_module.eval()

    def infer(self):
        """
        Runs the complete inference workflow.

        Steps performed:

        1. Switches tracked modules to eval mode and waits for all ranks.
        2. Loads the last checkpoint, if any. Returns early, without
           finalizing the task, when the stepper reports that all steps are
           already done.
        3. For each batch group from the data loader: runs the forward pass
           of every batch, performs periodic GC, advances the stepper and the
           progress bar, checkpoints if needed and steps the profiler.
        4. Finalizes the task.
        """

        state = self._state

        with torch.inference_mode():
            self._enable_eval_mode()

            state.dist_context.logger.info("Waiting for the world to start job")
            state.dist_context.wait_world()
            state.dist_context.logger.info("Trying to load last checkpoint before doing anything else")
            state.checkpointer.load_last_checkpoint(state)

            # NOTE(review): InferenceConfigurator builds the Stepper with
            # initial_step=1 and total_steps=len(loader). If current_step
            # starts at initial_step, a fresh run over a one-batch loader would
            # already satisfy this check. Confirm the Stepper semantics.
            already_done = state.stepper.current_step >= state.stepper.total_steps
            if already_done:
                state.dist_context.logger.info("Already ran, will do nothing")
                return

            state.dist_context.wait_world()

            # The progress bar is only shown on the local main process.
            progress = tqdm(
                desc="Inference",
                total=state.stepper.total_steps,
                disable=not state.dist_context.is_local_main_process,
                initial=state.stepper.current_step
            )

            with (
                progress as bar,
                state.garbage_collector as collector,
                state.profiler.open() as prof
            ):
                state.timeout_manager.step()

                for group in state.data_loader:
                    for batch in group:
                        state.task_operator.forward(batch)

                    collector.collect_periodic()
                    state.stepper.step()
                    bar.update()

                    # Checkpoint only after the whole step has completed.
                    state.checkpointer.checkpoint_if_needed(state)

                    # `open()` may yield a falsy value (e.g. when profiling is off).
                    if prof:
                        prof.step()

            state.task.finalize(FinalizeContext())
|
d9d/loop/run/train.py
CHANGED
|
@@ -121,7 +121,7 @@ class TrainingConfigurator:
|
|
|
121
121
|
config_model=self._parameters.model_stage_factory,
|
|
122
122
|
config_pipelining=self._parameters.pipelining,
|
|
123
123
|
batch_maths=batch_maths,
|
|
124
|
-
|
|
124
|
+
pipeline_callback=loss_computer
|
|
125
125
|
).build_pipeline_and_modules()
|
|
126
126
|
|
|
127
127
|
metrics = ComposeMetric(task.create_metrics(CreateMetricsContext()).metrics)
|
|
@@ -130,9 +130,7 @@ class TrainingConfigurator:
|
|
|
130
130
|
task_operator = TrainTaskOperator(
|
|
131
131
|
dist_context=dist_context,
|
|
132
132
|
task=task,
|
|
133
|
-
|
|
134
|
-
tracked_modules=modules,
|
|
135
|
-
loss_computer=loss_computer,
|
|
133
|
+
pipeline=schedule,
|
|
136
134
|
pipeline_state=pipeline_state_handler,
|
|
137
135
|
metrics=metrics
|
|
138
136
|
)
|
d9d/loop/state.py
CHANGED
|
@@ -10,6 +10,7 @@ from d9d.loop.component import (
|
|
|
10
10
|
BatchMaths,
|
|
11
11
|
GradientClipper,
|
|
12
12
|
GradientManager,
|
|
13
|
+
InferenceTaskOperator,
|
|
13
14
|
JobLogger,
|
|
14
15
|
JobProfiler,
|
|
15
16
|
ManualGarbageCollector,
|
|
@@ -123,14 +124,16 @@ class TrainJobState(JobState):
|
|
|
123
124
|
|
|
124
125
|
|
|
125
126
|
@dataclasses.dataclass(kw_only=True)
|
|
126
|
-
class
|
|
127
|
+
class InferenceJobState(JobState):
|
|
127
128
|
"""
|
|
128
129
|
Container for the state of an inference job.
|
|
129
130
|
|
|
130
131
|
Attributes:
|
|
131
132
|
task: The specific inference task logic definition.
|
|
133
|
+
task_operator: Executor for running forward passes through the inference pipeline.
|
|
132
134
|
"""
|
|
133
135
|
task: InferenceTask
|
|
136
|
+
task_operator: InferenceTaskOperator
|
|
134
137
|
|
|
135
138
|
def state_dict(self) -> dict[str, Any]:
|
|
136
139
|
return {
|
d9d/pipelining/api/__init__.py
CHANGED
|
@@ -9,9 +9,12 @@ from .module import (
|
|
|
9
9
|
)
|
|
10
10
|
from .schedule import PipelineSchedule
|
|
11
11
|
from .sharding import PipelineShardingSpec
|
|
12
|
+
from .types import PipelineLossFn, PipelineResultFn
|
|
12
13
|
|
|
13
14
|
__all__ = [
|
|
14
15
|
"ModuleSupportsPipelining",
|
|
16
|
+
"PipelineLossFn",
|
|
17
|
+
"PipelineResultFn",
|
|
15
18
|
"PipelineSchedule",
|
|
16
19
|
"PipelineShardingSpec",
|
|
17
20
|
"PipelineStageInfo",
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
# Mapping from output name to tensor, as produced by a pipeline stage.
_StageOutputs = dict[str, torch.Tensor]

PipelineResultFn = Callable[[_StageOutputs, int], Any]
"""
Signature of a callback that consumes the outputs of the final pipeline stage.

Args:
    outputs: Mapping from output names to tensors produced by the stage.
    microbatch_idx: Index of the micro-batch being processed.

Returns:
    Any value; the pipeline discards it.
"""

PipelineLossFn = Callable[[_StageOutputs, int], torch.Tensor]
"""
Signature of a callback that computes the loss on the final pipeline stage.

Args:
    outputs: Mapping from output names to tensors produced by the model.
    microbatch_idx: Index of the micro-batch being processed.

Returns:
    The computed loss as a scalar tensor.
"""
|
|
@@ -1,19 +1,19 @@
|
|
|
1
1
|
import dataclasses
|
|
2
2
|
from collections.abc import Callable
|
|
3
3
|
|
|
4
|
-
import torch
|
|
5
4
|
from torch import nn
|
|
6
5
|
|
|
7
6
|
from ...core.dist_context import REGULAR_DOMAIN, DistributedContext
|
|
8
|
-
from ..api import PipelineSchedule, PipelineStageInfo
|
|
7
|
+
from ..api import PipelineLossFn, PipelineResultFn, PipelineSchedule, PipelineStageInfo
|
|
9
8
|
from ..infra.schedule.component.program import (
|
|
10
9
|
build_stage_to_host_rank_topology,
|
|
11
10
|
invert_stage_to_host_rank_topology,
|
|
12
11
|
)
|
|
13
|
-
from ..infra.schedule.component.runtime import PipelineScheduleExecutor
|
|
12
|
+
from ..infra.schedule.component.runtime import OfflinePipelineExecutor, PipelineScheduleExecutor
|
|
14
13
|
from ..infra.stage import PipelineStage
|
|
15
14
|
from .config import (
|
|
16
15
|
AnyPipelineScheduleConfig,
|
|
16
|
+
PipelineScheduleInferenceConfig,
|
|
17
17
|
)
|
|
18
18
|
from .registry import PIPELINE_PROGRAM_REGISTRY
|
|
19
19
|
|
|
@@ -27,38 +27,31 @@ class PipelineScheduleInfo:
|
|
|
27
27
|
has_last_stage: bool
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
def
|
|
31
|
-
dist_context: DistributedContext,
|
|
32
|
-
n_microbatches: int,
|
|
30
|
+
def _build_schedule_local(
|
|
33
31
|
schedule_config: AnyPipelineScheduleConfig,
|
|
34
32
|
model_provider: Callable[[PipelineStageInfo], nn.Module],
|
|
35
|
-
|
|
33
|
+
callback: PipelineLossFn | PipelineResultFn
|
|
36
34
|
) -> tuple[PipelineScheduleInfo, list[nn.Module]]:
|
|
37
|
-
|
|
38
|
-
Constructs the pipeline schedule and instantiates model stages.
|
|
35
|
+
stage_info = PipelineStageInfo(num_stages=1, current_stage=0)
|
|
39
36
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
3. Instantiates the local model stages for the current rank using `model_provider`.
|
|
44
|
-
4. Wraps models in `PipelineStage` containers.
|
|
45
|
-
5. Generates the execution program (action list).
|
|
46
|
-
6. Builds the runtime executor.
|
|
37
|
+
model = model_provider(stage_info)
|
|
38
|
+
has_backward = not isinstance(schedule_config, PipelineScheduleInferenceConfig)
|
|
39
|
+
scheduler = OfflinePipelineExecutor(model=model, callback=callback, do_backward=has_backward)
|
|
47
40
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
for that specific stage.
|
|
54
|
-
loss_fn: Optional loss function. Required if training (backward pass needed).
|
|
41
|
+
return PipelineScheduleInfo(
|
|
42
|
+
schedule=scheduler,
|
|
43
|
+
has_first_stage=True,
|
|
44
|
+
has_last_stage=True
|
|
45
|
+
), [model]
|
|
55
46
|
|
|
56
|
-
Returns:
|
|
57
|
-
A tuple containing:
|
|
58
|
-
1. `PipelineScheduleInfo`: The executable schedule and metadata.
|
|
59
|
-
2. `list[nn.Module]`: The local PyTorch modules created for this rank.
|
|
60
|
-
"""
|
|
61
47
|
|
|
48
|
+
def _build_schedule_distributed(
|
|
49
|
+
dist_context: DistributedContext,
|
|
50
|
+
n_microbatches: int,
|
|
51
|
+
schedule_config: AnyPipelineScheduleConfig,
|
|
52
|
+
model_provider: Callable[[PipelineStageInfo], nn.Module],
|
|
53
|
+
callback: PipelineLossFn | PipelineResultFn,
|
|
54
|
+
) -> tuple[PipelineScheduleInfo, list[nn.Module]]:
|
|
62
55
|
program_builder = PIPELINE_PROGRAM_REGISTRY.program_for(schedule_config)
|
|
63
56
|
mesh = dist_context.mesh_for(REGULAR_DOMAIN)["pp"]
|
|
64
57
|
|
|
@@ -103,7 +96,7 @@ def build_schedule(
|
|
|
103
96
|
dist_context=dist_context,
|
|
104
97
|
stages=stages,
|
|
105
98
|
num_microbatches=n_microbatches,
|
|
106
|
-
|
|
99
|
+
callback=callback,
|
|
107
100
|
program=program
|
|
108
101
|
)
|
|
109
102
|
|
|
@@ -112,3 +105,49 @@ def build_schedule(
|
|
|
112
105
|
has_first_stage=has_first_stage,
|
|
113
106
|
has_last_stage=has_last_stage
|
|
114
107
|
), modules
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def build_schedule(
    dist_context: DistributedContext,
    n_microbatches: int,
    schedule_config: AnyPipelineScheduleConfig,
    model_provider: Callable[[PipelineStageInfo], nn.Module],
    callback: PipelineLossFn | PipelineResultFn,
) -> tuple[PipelineScheduleInfo, list[nn.Module]]:
    """
    Creates the pipeline executor together with the model stages owned by this rank.

    Two construction paths exist:

    * Distributed context: the stage topology is computed and the local stages
      are driven by a parallel ``PipelineScheduleExecutor``.
    * Local (non-distributed) context: the whole model is built as one stage
      and executed directly by an ``OfflinePipelineExecutor``.

    Args:
        dist_context: The distributed context.
        n_microbatches: Number of microbatches per global step. Only used on
            the distributed path.
        schedule_config: Configuration object selecting the schedule strategy.
        model_provider: Factory that receives stage info and returns the
            ``nn.Module`` for that stage.
        callback: Either a loss function (training) or a result-processing
            function (inference), invoked on final-stage outputs.

    Returns:
        A tuple of the schedule info (executor plus first/last-stage flags)
        and the list of PyTorch modules created for this rank.
    """

    if not dist_context.mesh_params.is_distributed:
        # Single process: no stage topology or microbatch program is needed.
        return _build_schedule_local(
            schedule_config=schedule_config,
            model_provider=model_provider,
            callback=callback
        )

    return _build_schedule_distributed(
        dist_context=dist_context,
        n_microbatches=n_microbatches,
        schedule_config=schedule_config,
        model_provider=model_provider,
        callback=callback
    )
|
|
@@ -14,6 +14,7 @@ from .action import (
|
|
|
14
14
|
ForwardSendAction,
|
|
15
15
|
)
|
|
16
16
|
from .executor import PipelineScheduleExecutor
|
|
17
|
+
from .offline import OfflinePipelineExecutor
|
|
17
18
|
|
|
18
19
|
__all__ = [
|
|
19
20
|
"ActionBase",
|
|
@@ -25,5 +26,6 @@ __all__ = [
|
|
|
25
26
|
"ForwardComputeAction",
|
|
26
27
|
"ForwardReceiveAction",
|
|
27
28
|
"ForwardSendAction",
|
|
28
|
-
"
|
|
29
|
+
"OfflinePipelineExecutor",
|
|
30
|
+
"PipelineScheduleExecutor"
|
|
29
31
|
]
|
|
@@ -7,8 +7,8 @@ import torch
|
|
|
7
7
|
|
|
8
8
|
from d9d.pipelining.infra.stage import PipelineStage
|
|
9
9
|
|
|
10
|
+
from .callback import PipelineLossHandler, PipelineResultHandler
|
|
10
11
|
from .communications import PipelineCommunicationHandler
|
|
11
|
-
from .loss import PipelineLossHandler
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
@dataclasses.dataclass(kw_only=True, slots=True)
|
|
@@ -21,7 +21,7 @@ class ActionContext:
|
|
|
21
21
|
pipeline_kwargs_microbatches: The global keyword arguments sharded by microbatch.
|
|
22
22
|
stages: A mapping of stage indices to their active PipelineStage instances.
|
|
23
23
|
communications: The handler for P2P communications.
|
|
24
|
-
|
|
24
|
+
callback: The handler for either loss computation or result processing.
|
|
25
25
|
"""
|
|
26
26
|
|
|
27
27
|
pipeline_inputs_microbatches: tuple[dict[str, torch.Tensor], ...]
|
|
@@ -29,7 +29,7 @@ class ActionContext:
|
|
|
29
29
|
|
|
30
30
|
stages: dict[int, PipelineStage]
|
|
31
31
|
communications: PipelineCommunicationHandler
|
|
32
|
-
|
|
32
|
+
callback: PipelineLossHandler | PipelineResultHandler
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class ActionWorkType(StrEnum):
|
|
@@ -208,7 +208,6 @@ class ForwardComputeAction(ActionBase):
|
|
|
208
208
|
microbatch_idx: int
|
|
209
209
|
|
|
210
210
|
def apply(self, ctx: ActionContext):
|
|
211
|
-
# todo check unsharded
|
|
212
211
|
stage = ctx.stages[self.stage_idx]
|
|
213
212
|
|
|
214
213
|
if not stage.info.is_current_stage_first and self.stage_idx - 1 not in ctx.stages:
|
|
@@ -221,8 +220,8 @@ class ForwardComputeAction(ActionBase):
|
|
|
221
220
|
)
|
|
222
221
|
result = stage.get_local_fwd_output(self.microbatch_idx)
|
|
223
222
|
|
|
224
|
-
if stage.info.is_current_stage_last
|
|
225
|
-
ctx.
|
|
223
|
+
if stage.info.is_current_stage_last:
|
|
224
|
+
ctx.callback.trigger(result, self.microbatch_idx)
|
|
226
225
|
|
|
227
226
|
if not stage.info.is_current_stage_last and self.stage_idx + 1 in ctx.stages:
|
|
228
227
|
ctx.stages[self.stage_idx + 1].set_local_fwd_input(
|
|
@@ -260,14 +259,13 @@ class BackwardFullInputComputeAction(ActionBase):
|
|
|
260
259
|
full_backward: bool
|
|
261
260
|
|
|
262
261
|
def apply(self, ctx: ActionContext):
|
|
263
|
-
# todo unshard
|
|
264
262
|
stage = ctx.stages[self.stage_idx]
|
|
265
263
|
|
|
266
264
|
if not stage.info.is_current_stage_last and self.stage_idx + 1 not in ctx.stages:
|
|
267
265
|
ctx.communications.wait_bwd_recv(self.stage_idx, self.microbatch_idx)
|
|
268
266
|
|
|
269
|
-
if stage.info.is_current_stage_last and ctx.
|
|
270
|
-
loss = ctx.
|
|
267
|
+
if stage.info.is_current_stage_last and isinstance(ctx.callback, PipelineLossHandler):
|
|
268
|
+
loss = ctx.callback.acquire_loss(self.microbatch_idx)
|
|
271
269
|
else:
|
|
272
270
|
loss = None
|
|
273
271
|
|
|
@@ -310,7 +308,6 @@ class BackwardWeightComputeAction(ActionBase):
|
|
|
310
308
|
microbatch_idx: int
|
|
311
309
|
|
|
312
310
|
def apply(self, ctx: ActionContext):
|
|
313
|
-
# todo unshard
|
|
314
311
|
stage = ctx.stages[self.stage_idx]
|
|
315
312
|
|
|
316
313
|
stage.backward_weight_one_chunk(
|
|
@@ -1,14 +1,41 @@
|
|
|
1
|
-
from collections.abc import Callable
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
|
|
5
|
-
|
|
3
|
+
from d9d.pipelining.api import PipelineLossFn, PipelineResultFn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class PipelineResultHandler:
    """
    Adapter that forwards final-stage pipeline outputs to a user callback.

    The callback's return value is discarded.
    """

    def __init__(self, callback_fn: PipelineResultFn):
        """
        Creates the handler.

        Args:
            callback_fn: Function invoked by ``trigger`` with
                ``(forward_result, microbatch_index)``.
        """

        self._callback_fn = callback_fn

    def trigger(self, forward_result: dict[str, torch.Tensor], microbatch_index: int):
        """
        Passes one micro-batch's outputs to the wrapped callback.

        Args:
            forward_result: Output tensors of the final stage, keyed by name.
            microbatch_index: Index of the micro-batch that produced the outputs.
        """

        callback = self._callback_fn
        callback(forward_result, microbatch_index)
|
|
6
31
|
|
|
7
32
|
|
|
8
33
|
class PipelineLossHandler:
|
|
9
|
-
"""
|
|
34
|
+
"""
|
|
35
|
+
Manages loss computation and state caching across forward and backward passes.
|
|
36
|
+
"""
|
|
10
37
|
|
|
11
|
-
def __init__(self, loss_fn:
|
|
38
|
+
def __init__(self, loss_fn: PipelineLossFn):
|
|
12
39
|
"""
|
|
13
40
|
Constructs the loss handler.
|
|
14
41
|
|
|
@@ -19,21 +46,17 @@ class PipelineLossHandler:
|
|
|
19
46
|
self._loss_fn = loss_fn
|
|
20
47
|
self._cached_values: dict[int, torch.Tensor] = {}
|
|
21
48
|
|
|
22
|
-
def
|
|
49
|
+
def trigger(self, forward_result: dict[str, torch.Tensor], microbatch_index: int):
|
|
23
50
|
"""
|
|
24
51
|
Computes loss for a given microbatch result and caches it.
|
|
25
52
|
|
|
26
53
|
Args:
|
|
27
54
|
forward_result: The output from the last stage of the model.
|
|
28
55
|
microbatch_index: The index of the microbatch being processed.
|
|
29
|
-
|
|
30
|
-
Returns:
|
|
31
|
-
The computed loss tensor.
|
|
32
56
|
"""
|
|
33
57
|
|
|
34
58
|
result = self._loss_fn(forward_result, microbatch_index)
|
|
35
59
|
self._cached_values[microbatch_index] = result
|
|
36
|
-
return result
|
|
37
60
|
|
|
38
61
|
def acquire_loss(self, microbatch_index: int) -> torch.Tensor:
|
|
39
62
|
"""
|