d9d 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/kernel/swiglu/op.py +17 -4
- d9d/loop/component/__init__.py +5 -2
- d9d/loop/component/model_stage_factory.py +14 -54
- d9d/loop/component/{loss_computer.py → pipeline_result_processing.py} +73 -10
- d9d/loop/component/{train_task_operator.py → task_operator.py} +97 -57
- d9d/loop/config/config.py +1 -1
- d9d/loop/control/task.py +2 -2
- d9d/loop/run/__init__.py +3 -0
- d9d/loop/run/inference.py +256 -0
- d9d/loop/run/train.py +2 -4
- d9d/loop/state.py +4 -1
- d9d/pipelining/api/__init__.py +3 -0
- d9d/pipelining/api/types.py +28 -0
- d9d/pipelining/factory/factory.py +68 -29
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +3 -1
- d9d/pipelining/infra/schedule/component/runtime/action.py +7 -10
- d9d/pipelining/infra/schedule/component/runtime/{loss.py → callback.py} +33 -10
- d9d/pipelining/infra/schedule/component/runtime/executor.py +10 -8
- d9d/pipelining/infra/schedule/component/runtime/offline.py +70 -0
- {d9d-0.1.0.dist-info → d9d-0.2.0.dist-info}/METADATA +22 -1
- {d9d-0.1.0.dist-info → d9d-0.2.0.dist-info}/RECORD +23 -19
- {d9d-0.1.0.dist-info → d9d-0.2.0.dist-info}/WHEEL +1 -1
- d9d-0.2.0.dist-info/licenses/LICENSE +201 -0
d9d/kernel/swiglu/op.py
CHANGED
@@ -3,6 +3,15 @@ import triton
 import triton.language as tl
 
 
+def _size_bucket(n_elements: int) -> int:
+    # different auto-tuning for small and asymptotically large kernels
+    # perhaps we could extend this in future?
+    if n_elements < 8192:
+        return 0
+    else:
+        return 1
+
+
 @triton.autotune(
     configs=[
         triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
@@ -11,7 +20,7 @@ import triton.language as tl
         triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
         triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
     ],
-    key=["
+    key=["size_bucket"]
 )
 @triton.jit
 def _silu_mul_kernel(
@@ -19,6 +28,7 @@ def _silu_mul_kernel(
     y_ptr: torch.Tensor,
     out_ptr: torch.Tensor,
     n_elements: int,
+    size_bucket: int,  # used for autotuning
     BLOCK_SIZE: tl.constexpr,
 ):
     # prepare
@@ -72,7 +82,8 @@ def silu_mul_forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
     _silu_mul_kernel[_grid](
         x, y, out,
-        n_elements
+        n_elements,
+        size_bucket=_size_bucket(n_elements)
     )
 
     return out
@@ -86,7 +97,7 @@ def silu_mul_forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
         triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
     ],
-    key=["
+    key=["size_bucket"]
 )
 @triton.jit
 def _silu_mul_backward_kernel(
@@ -96,6 +107,7 @@ def _silu_mul_backward_kernel(
     grad_x_ptr: torch.Tensor,
     grad_y_ptr: torch.Tensor,
     n_elements: int,
+    size_bucket: int,  # used for autotuning
     BLOCK_SIZE: tl.constexpr
 ):
     # prepare
@@ -161,7 +173,8 @@ def silu_mul_backward(
     _silu_mul_backward_kernel[_grid](
         grad_output, x, y,
         grad_x, grad_y,
-        n_elements
+        n_elements,
+        size_bucket=_size_bucket(n_elements)
    )
 
     return grad_x, grad_y
d9d/loop/component/__init__.py
CHANGED
@@ -6,13 +6,13 @@ from .gradient_clipper import GradientClipper
 from .gradient_manager import GradientManager
 from .job_logger import JobLogger
 from .job_profiler import JobProfiler
-from .loss_computer import LossComputer
 from .model_stage_exporter import ModelStageExporter
 from .model_stage_factory import ModelStageFactory, TrackedModules
 from .optimizer_factory import OptimizerFactory
+from .pipeline_result_processing import InferenceProcessor, LossComputer, PipelineOutputsProcessor
 from .stepper import Stepper
+from .task_operator import ForwardResult, InferenceTaskOperator, TrainTaskOperator
 from .timeout_manager import TimeoutManager
-from .train_task_operator import ForwardResult, TrainTaskOperator
 
 __all__ = [
     "BatchMaths",
@@ -20,6 +20,8 @@ __all__ = [
     "ForwardResult",
     "GradientClipper",
     "GradientManager",
+    "InferenceProcessor",
+    "InferenceTaskOperator",
     "JobLogger",
     "JobProfiler",
     "LossComputer",
@@ -27,6 +29,7 @@ __all__ = [
     "ModelStageExporter",
     "ModelStageFactory",
     "OptimizerFactory",
+    "PipelineOutputsProcessor",
     "StateCheckpointer",
     "Stepper",
     "TimeoutManager",
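For orientation, the 0.2.0 surface of this module can be pulled in as below; the import list is illustrative and simply restates the re-exports above:

    from d9d.loop.component import (
        ForwardResult,
        InferenceProcessor,
        InferenceTaskOperator,
        LossComputer,
        PipelineOutputsProcessor,
        TrainTaskOperator,
    )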
d9d/loop/component/model_stage_factory.py
CHANGED
@@ -15,7 +15,7 @@ from d9d.pipelining.api import PipelineStageInfo
 from d9d.pipelining.factory.factory import PipelineScheduleInfo, build_schedule
 
 from .batch_maths import BatchMaths
-from .
+from .pipeline_result_processing import PipelineOutputsProcessor
 
 StatefulPredicate = Callable[[str, torch.Tensor], bool]
 """Determines if a specific parameter or buffer should be included in the state dictionary."""
@@ -51,28 +51,6 @@ class TrackedModules(Stateful):
         self._modules = modules
         self._stateful_predicate = stateful_predicate
 
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        """
-        Forwards execution to the only pipeline stage.
-
-        This method is only valid when pipeline parallelism is disabled.
-
-        Args:
-            *args: Positional arguments passed to the module.
-            **kwargs: Keyword arguments passed to the module.
-
-        Returns:
-            The output of the model execution.
-
-        Raises:
-            ValueError: If pipeline parallelism is configured.
-        """
-
-        if self._dist_context.mesh_params.has_pipeline_parallel:
-            raise ValueError("You cannot call tracked modules when using pipelining")
-
-        return self._modules[0](*args, **kwargs)
-
     @property
     def modules(self) -> list[nn.Module]:
         """Returns the list of underlying PyTorch model modules."""
@@ -159,8 +137,8 @@ class ModelStageFactory:
         dist_context: DistributedContext,
         batch_maths: BatchMaths,
         config_model: ModelStageFactoryConfig,
-        config_pipelining: PipeliningConfig
-
+        config_pipelining: PipeliningConfig,
+        pipeline_callback: PipelineOutputsProcessor
     ):
         """Constructs a ModelStageFactory object."""
 
@@ -169,7 +147,7 @@ class ModelStageFactory:
         self._config_model = config_model
         self._config_pipelining = config_pipelining
         self._batch_maths = batch_maths
-        self.
+        self._pipeline_callback = pipeline_callback
 
     def _build_model_stage(self, stage: PipelineStageInfo) -> nn.Module:
         # create a model with no real memory occupied
@@ -218,21 +196,13 @@ class ModelStageFactory:
 
     def build_pipeline_and_modules(
         self
-    ) -> tuple[PipelineScheduleInfo
+    ) -> tuple[PipelineScheduleInfo, TrackedModules]:
         """
         Constructs the execution schedule and the model container.
 
-        If pipeline parallelism is enabled, this orchestrates the creation of a
-        distributed pipeline schedule.
-
-        Otherwise, it simply builds a standalone model stage.
-
         Returns:
-            The pipeline schedule information
+            The pipeline schedule information.
             The `TrackedModules` instance wrapping the created model stage(s).
-
-        Raises:
-            ValueError: If pipelining configuration is missing but a pipeline is requested.
         """
 
         if self._config_model.checkpoint_only_trainable_parameters:
@@ -240,22 +210,12 @@ class ModelStageFactory:
         else:
             stateful_predicate = _stateful_predicate_always
 
-
-
-
-
-
-
-
-                dist_context=self._dist_context,
-                n_microbatches=self._batch_maths.num_microbatches_pipelining,
-                schedule_config=self._config_pipelining.schedule,
-                model_provider=self._build_model_stage,
-                loss_fn=loss_fn
-            )
-
-            return schedule, TrackedModules(self._dist_context, modules, stateful_predicate)
-        else:
-            model = self._build_model_stage(PipelineStageInfo(num_stages=1, current_stage=0))
+        schedule, modules = build_schedule(
+            dist_context=self._dist_context,
+            n_microbatches=self._batch_maths.num_microbatches_pipelining,
+            schedule_config=self._config_pipelining.schedule,
+            model_provider=self._build_model_stage,
+            callback=self._pipeline_callback
+        )
 
-
+        return schedule, TrackedModules(self._dist_context, modules, stateful_predicate)
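A hypothetical call site for the now unconditional pipeline path; only the names shown in the hunks above come from d9d, and the variables passed in are placeholders for config and context objects built elsewhere:

    # a minimal sketch, assuming dist_context / batch_maths / configs exist
    factory = ModelStageFactory(
        dist_context=dist_context,
        batch_maths=batch_maths,
        config_model=model_cfg,
        config_pipelining=pipelining_cfg,
        pipeline_callback=loss_computer,  # any PipelineOutputsProcessor
    )
    schedule_info, tracked = factory.build_pipeline_and_modules()
    for module in tracked.modules:  # TrackedModules no longer forwards __call__
        module.train()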
d9d/loop/component/{loss_computer.py → pipeline_result_processing.py}
RENAMED
@@ -1,7 +1,10 @@
+import abc
+from typing import Generic, TypeVar
+
 import torch
 
 from d9d.internals.pipeline_state import PipelineStateHandler
-from d9d.loop.control import ComputeLossContext, TrainTask
+from d9d.loop.control import ComputeLossContext, InferenceTask, ProcessOutputsContext, TrainTask
 
 from .stepper import Stepper
 
@@ -9,7 +12,20 @@ STATE_LOSS = "__internal_loss"
 STATE_LOSS_WEIGHT = "__internal_loss_weight"
 
 
-
+TOutput = TypeVar("TOutput")
+
+
+class PipelineOutputsProcessor(abc.ABC, Generic[TOutput]):
+    @abc.abstractmethod
+    def __call__(
+        self,
+        pipeline_outputs: dict[str, torch.Tensor],
+        microbatch_idx: int
+    ) -> TOutput:
+        ...
+
+
+class LossComputer(PipelineOutputsProcessor[torch.Tensor]):
     """
     Handles the computation of loss values and their integration into the pipeline state.
 
@@ -38,10 +54,10 @@ class LossComputer:
         self._task = task
         self._stepper = stepper
 
-    def
+    def __call__(
         self,
         pipeline_outputs: dict[str, torch.Tensor],
-        microbatch_idx: int
+        microbatch_idx: int
     ) -> torch.Tensor:
         """
         Computes the weighted loss for a specific sharded microbatch or the full microbatch.
@@ -61,12 +77,9 @@
             The calculated loss multiplied by its weight.
         """
 
-
-
-
-        state = self._state.sharded_state(
-            shard_id=microbatch_idx
-        )
+        state = self._state.sharded_state(
+            shard_id=microbatch_idx
+        )
 
         computation = self._task.compute_loss(ComputeLossContext(
             pipeline_results=pipeline_outputs,
@@ -84,3 +97,53 @@ class LossComputer:
         state[STATE_LOSS_WEIGHT] = loss_weight[None]
 
         return loss * loss_weight
+
+
+class InferenceProcessor(PipelineOutputsProcessor[None]):
+    """
+    Handles the processing of model outputs during inference or evaluation.
+
+    This component retrieves the appropriate state context
+    and delegates the output processing logic to the user-defined inference task.
+    """
+
+    def __init__(
+        self,
+        state: PipelineStateHandler,
+        task: InferenceTask
+    ):
+        """
+        Constructs a new InferenceProcessor.
+
+        Args:
+            state: Handler for managing global and sharded pipeline states.
+            task: The user-defined inference task containing processing logic.
+        """
+
+        self._state = state
+        self._task = task
+
+    def __call__(
+        self,
+        pipeline_outputs: dict[str, torch.Tensor],
+        microbatch_idx: int
+    ) -> None:
+        """
+        Processes model outputs for a specific microbatch or full batch.
+
+        This method retrieves the relevant state (scoped by microbatch index if provided)
+        and invokes the task's output processing logic.
+
+        Args:
+            pipeline_outputs: Dictionary containing model output tensors.
+            microbatch_idx: Index of the current microbatch, or None if not using microbatching.
+        """
+
+        state = self._state.sharded_state(
+            shard_id=microbatch_idx
+        )
+
+        self._task.process_outputs(ProcessOutputsContext(
+            pipeline_results=pipeline_outputs,
+            state=state
+        ))
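The new ABC makes the last-stage callback pluggable beyond loss computation. A sketch of a custom processor against the same contract; the accuracy logic and the "logits"/"labels" output names are hypothetical, not part of d9d:

    import torch

    from d9d.loop.component import PipelineOutputsProcessor


    class TokenAccuracyProcessor(PipelineOutputsProcessor[None]):
        """Accumulates argmax accuracy from last-stage outputs."""

        def __init__(self) -> None:
            self.correct = 0
            self.total = 0

        def __call__(
            self,
            pipeline_outputs: dict[str, torch.Tensor],
            microbatch_idx: int
        ) -> None:
            preds = pipeline_outputs["logits"].argmax(dim=-1)
            labels = pipeline_outputs["labels"]
            self.correct += int((preds == labels).sum())
            self.total += labels.numel()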
d9d/loop/component/{train_task_operator.py → task_operator.py}
RENAMED
@@ -5,12 +5,17 @@ import torch
 from d9d.core.dist_context import DistributedContext
 from d9d.core.types import PyTree
 from d9d.internals.pipeline_state import PipelineStateHandler
-from d9d.loop.control import
+from d9d.loop.control import (
+    BaseTask,
+    BuildForwardInputsContext,
+    InferenceTask,
+    TrainTask,
+    UpdateMetricsContext,
+)
 from d9d.metric.impl import ComposeMetric
 from d9d.pipelining.factory.factory import PipelineScheduleInfo
 
-from .
-from .model_stage_factory import TrackedModules
+from .pipeline_result_processing import STATE_LOSS, STATE_LOSS_WEIGHT
 
 
 @dataclasses.dataclass(kw_only=True)
@@ -27,12 +32,34 @@ class ForwardResult:
     loss_weight: torch.Tensor
 
 
+def _run_pipeline(
+    task: BaseTask,
+    pipeline: PipelineScheduleInfo,
+    pipeline_state: PipelineStateHandler,
+    batch: PyTree
+):
+    model_inputs = task.build_forward_inputs(
+        BuildForwardInputsContext(
+            batch=batch,
+            state=pipeline_state.global_state()
+        )
+    )
+    pipeline.schedule.configure_buffers(
+        inputs=model_inputs.inputs,
+        kwargs=model_inputs.kwargs,
+        sharding_spec=model_inputs.pipeline_sharding_spec
+    )
+    pipeline.schedule.step(
+        inputs=model_inputs.inputs,
+        kwargs=model_inputs.kwargs
+    )
+
+
 class TrainTaskOperator:
     """
     Orchestrates the execution of the forward and backward passes for a specific training task.
 
-
-    and pipeline-parallel execution. It manages input construction, schedule execution,
+    It manages input construction, schedule execution,
     loss computation, and metric updates within the lifecycle of a single step.
     """
 
@@ -40,9 +67,7 @@ class TrainTaskOperator:
         self,
         dist_context: DistributedContext,
         task: TrainTask,
-
-        tracked_modules: TrackedModules,
-        loss_computer: LossComputer,
+        pipeline: PipelineScheduleInfo,
         pipeline_state: PipelineStateHandler,
         metrics: ComposeMetric
     ):
@@ -52,48 +77,17 @@ class TrainTaskOperator:
         Args:
             dist_context: The distributed context.
             task: The user-defined training task logic.
-
-            tracked_modules: The model modules being trained.
-            loss_computer: Component responsible for calculating loss from outputs.
+            pipeline: Information about the pipeline schedule.
             pipeline_state: Handler for transient state storage during the step.
             metrics: Metric collection to update after the pass.
         """
 
         self._dist_context = dist_context
         self._task = task
-        self.
-        self._tracked_modules = tracked_modules
-        self._loss_computer = loss_computer
+        self._pipeline = pipeline
         self._pipeline_state = pipeline_state
         self._metrics = metrics
 
-    def _forward_backward_pipelining(self, model_inputs: BuildForwardInputsResult):
-        if self._pp_schedule is None:
-            raise ValueError("Cannot run pipelined pass if pipelining is disabled")
-
-        self._pp_schedule.schedule.configure_buffers(
-            inputs=model_inputs.inputs,
-            kwargs=model_inputs.kwargs,
-            sharding_spec=model_inputs.pipeline_sharding_spec
-        )
-        self._pp_schedule.schedule.step(
-            inputs=model_inputs.inputs,
-            kwargs=model_inputs.kwargs
-        )
-
-    def _forward_backward_regular(self, model_inputs: BuildForwardInputsResult):
-        pipeline_outputs = self._tracked_modules(
-            **model_inputs.inputs,
-            **model_inputs.kwargs
-        )
-        loss = self._loss_computer.compute_loss_mul_weight(
-            pipeline_outputs=pipeline_outputs,
-            microbatch_idx=None
-        )
-        # free to avoid bwd peaking memory
-        del pipeline_outputs
-        loss.backward()
-
     def forward_backward(self, batch: PyTree) -> ForwardResult | None:
         """
         Executes the forward and backward passes for a single batch.
@@ -117,27 +111,17 @@ class TrainTaskOperator:
 
         try:
             # Do forward and backward pass
-
-
-
-
-
+            _run_pipeline(
+                pipeline_state=self._pipeline_state,
+                task=self._task,
+                pipeline=self._pipeline,
+                batch=batch
             )
 
-            if self._dist_context.mesh_params.has_pipeline_parallel:
-                self._forward_backward_pipelining(model_inputs)
-            else:
-                self._forward_backward_regular(model_inputs)
-
             # Update metrics if possible
 
             pipeline_state = self._pipeline_state.global_state()
-
-            if (
-                self._dist_context.mesh_params.has_pipeline_parallel and
-                self._pp_schedule is not None and
-                not self._pp_schedule.has_last_stage
-            ):
+            if not self._pipeline.has_last_stage:
                 return None
 
             self._task.update_metrics(UpdateMetricsContext(
@@ -150,3 +134,59 @@
             )
         finally:
             self._pipeline_state.reset()
+
+
+class InferenceTaskOperator:
+    """
+    Orchestrates the execution of the forward pass for a specific inference task.
+
+    It manages input
+    construction, schedule execution, and state lifecycle management.
+    """
+
+    def __init__(
+        self,
+        dist_context: DistributedContext,
+        task: InferenceTask,
+        pipeline: PipelineScheduleInfo,
+        pipeline_state: PipelineStateHandler
+    ):
+        """
+        Constructs the InferenceTaskOperator.
+
+        Args:
+            dist_context: The distributed context.
+            task: The user-defined inference task logic.
+            pipeline: Information about the pipeline schedule.
+            pipeline_state: Handler for transient state storage during the step.
+        """
+
+        self._dist_context = dist_context
+        self._task = task
+        self._pipeline = pipeline
+        self._pipeline_state = pipeline_state
+
+    def forward(self, batch: PyTree) -> None:
+        """
+        Executes the forward pass for a single batch.
+
+        This method handles:
+
+        1. Context preparation and input building via the `InferenceTask`.
+        2. Execution via Pipeline Parallel schedule.
+        3. Reliable cleanup of the pipeline state.
+
+        Args:
+            batch: The input batch data.
+        """
+
+        try:
+            # Do forward pass
+            _run_pipeline(
+                pipeline_state=self._pipeline_state,
+                task=self._task,
+                pipeline=self._pipeline,
+                batch=batch
+            )
+        finally:
+            self._pipeline_state.reset()
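A hypothetical driver showing how the two operators split the loop; construction of the operators and the data iterables are placeholders, and only the `forward_backward`/`forward` semantics and the `None` return off the last stage come from this diff:

    def run_step(train_op, infer_op, train_batch, eval_batches):
        # training: backward pass included; ranks without the last
        # pipeline stage get None instead of a ForwardResult
        result = train_op.forward_backward(train_batch)
        if result is not None:
            print("loss weight:", result.loss_weight)

        # evaluation: forward only; outputs reach the InferenceTask via
        # its process_outputs hook rather than a return value
        for batch in eval_batches:
            infer_op.forward(batch)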
d9d/loop/config/config.py
CHANGED
@@ -190,7 +190,7 @@ class TrainerConfig(BaseModel):
     batching: BatchingConfig
     data_loading: DataLoadingConfig
     logging: JobLoggerConfig
-    pipelining: PipeliningConfig
+    pipelining: PipeliningConfig
     model_stage_factory: ModelStageFactoryConfig
     determinism: DeterminismConfig
     gc: GarbageCollectionConfig
d9d/loop/control/task.py
CHANGED
@@ -252,11 +252,11 @@ class ProcessOutputsContext:
     Context data provided to process outputs during inference.
 
     Attributes:
-
+        pipeline_results: The outputs returned by the model's forward pass.
         state: The current state of the pipeline.
     """
 
-
+    pipeline_results: dict[str, torch.Tensor]
    state: "PipelineState"
 
 
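A minimal sketch of a task consuming the renamed field; `InferenceTask`'s full interface is not shown in this diff, and the "logits"/"predictions" keys are hypothetical:

    from d9d.loop.control import InferenceTask, ProcessOutputsContext


    class ArgmaxTask(InferenceTask):
        def process_outputs(self, context: ProcessOutputsContext) -> None:
            # pipeline_results holds the named output tensors of the model's
            # forward pass; the state supports item assignment (cf. LossComputer)
            logits = context.pipeline_results["logits"]
            context.state["predictions"] = logits.argmax(dim=-1)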