d9d 0.1.1-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
d9d/kernel/swiglu/op.py CHANGED
@@ -3,6 +3,15 @@ import triton
 import triton.language as tl
 
 
+def _size_bucket(n_elements: int) -> int:
+    # different auto-tuning for small and asymptotically large kernels
+    # perhaps we could extend this in future?
+    if n_elements < 8192:
+        return 0
+    else:
+        return 1
+
+
 @triton.autotune(
     configs=[
         triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
@@ -11,7 +20,7 @@ import triton.language as tl
         triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
         triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
     ],
-    key=["n_elements"]
+    key=["size_bucket"]
 )
 @triton.jit
 def _silu_mul_kernel(
@@ -19,6 +28,7 @@ def _silu_mul_kernel(
     y_ptr: torch.Tensor,
     out_ptr: torch.Tensor,
     n_elements: int,
+    size_bucket: int,  # used for autotuning
     BLOCK_SIZE: tl.constexpr,
 ):
     # prepare
@@ -72,7 +82,8 @@ def silu_mul_forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
     _silu_mul_kernel[_grid](
         x, y, out,
-        n_elements
+        n_elements,
+        size_bucket=_size_bucket(n_elements)
     )
 
     return out
@@ -86,7 +97,7 @@ def silu_mul_forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
         triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
     ],
-    key=["n_elements"]
+    key=["size_bucket"]
 )
 @triton.jit
 def _silu_mul_backward_kernel(
@@ -96,6 +107,7 @@ def _silu_mul_backward_kernel(
     grad_x_ptr: torch.Tensor,
     grad_y_ptr: torch.Tensor,
     n_elements: int,
+    size_bucket: int,  # used for autotuning
     BLOCK_SIZE: tl.constexpr
 ):
     # prepare
@@ -161,7 +173,8 @@ def silu_mul_backward(
     _silu_mul_backward_kernel[_grid](
         grad_output, x, y,
         grad_x, grad_y,
-        n_elements
+        n_elements,
+        size_bucket=_size_bucket(n_elements)
    )
 
     return grad_x, grad_y
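
The change above coarsens the Triton autotune cache key: keying on n_elements re-runs the tuning sweep for every distinct tensor size, while keying on the two-valued size_bucket tunes at most once per bucket (small vs. asymptotically large). A minimal sketch of the effect, assuming the package is installed and a CUDA device is available; the shapes are illustrative only:

    import torch
    from d9d.kernel.swiglu.op import silu_mul_forward  # module path as shown in this diff

    x1, y1 = torch.randn(2, 4096, device="cuda").unbind(0)
    x2, y2 = torch.randn(2, 5000, device="cuda").unbind(0)

    # Both inputs have fewer than 8192 elements, so _size_bucket() returns 0 for each.
    # With key=["size_bucket"] the second call reuses the BLOCK_SIZE/num_warps choice
    # tuned during the first call; with key=["n_elements"] it would re-tune from scratch.
    out1 = silu_mul_forward(x1, y1)
    out2 = silu_mul_forward(x2, y2)
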
@@ -6,13 +6,13 @@ from .gradient_clipper import GradientClipper
 from .gradient_manager import GradientManager
 from .job_logger import JobLogger
 from .job_profiler import JobProfiler
-from .loss_computer import LossComputer
 from .model_stage_exporter import ModelStageExporter
 from .model_stage_factory import ModelStageFactory, TrackedModules
 from .optimizer_factory import OptimizerFactory
+from .pipeline_result_processing import InferenceProcessor, LossComputer, PipelineOutputsProcessor
 from .stepper import Stepper
+from .task_operator import ForwardResult, InferenceTaskOperator, TrainTaskOperator
 from .timeout_manager import TimeoutManager
-from .train_task_operator import ForwardResult, TrainTaskOperator
 
 __all__ = [
     "BatchMaths",
@@ -20,6 +20,8 @@ __all__ = [
     "ForwardResult",
     "GradientClipper",
     "GradientManager",
+    "InferenceProcessor",
+    "InferenceTaskOperator",
     "JobLogger",
     "JobProfiler",
     "LossComputer",
@@ -27,6 +29,7 @@ __all__ = [
     "ModelStageExporter",
     "ModelStageFactory",
     "OptimizerFactory",
+    "PipelineOutputsProcessor",
     "StateCheckpointer",
     "Stepper",
     "TimeoutManager",
@@ -15,7 +15,7 @@ from d9d.pipelining.api import PipelineStageInfo
 from d9d.pipelining.factory.factory import PipelineScheduleInfo, build_schedule
 
 from .batch_maths import BatchMaths
-from .loss_computer import LossComputer
+from .pipeline_result_processing import PipelineOutputsProcessor
 
 StatefulPredicate = Callable[[str, torch.Tensor], bool]
 """Determines if a specific parameter or buffer should be included in the state dictionary."""
@@ -51,28 +51,6 @@ class TrackedModules(Stateful):
         self._modules = modules
         self._stateful_predicate = stateful_predicate
 
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        """
-        Forwards execution to the only pipeline stage.
-
-        This method is only valid when pipeline parallelism is disabled.
-
-        Args:
-            *args: Positional arguments passed to the module.
-            **kwargs: Keyword arguments passed to the module.
-
-        Returns:
-            The output of the model execution.
-
-        Raises:
-            ValueError: If pipeline parallelism is configured.
-        """
-
-        if self._dist_context.mesh_params.has_pipeline_parallel:
-            raise ValueError("You cannot call tracked modules when using pipelining")
-
-        return self._modules[0](*args, **kwargs)
-
     @property
     def modules(self) -> list[nn.Module]:
         """Returns the list of underlying PyTorch model modules."""
@@ -159,8 +137,8 @@ class ModelStageFactory:
         dist_context: DistributedContext,
         batch_maths: BatchMaths,
         config_model: ModelStageFactoryConfig,
-        config_pipelining: PipeliningConfig | None,
-        loss_computer: LossComputer | None
+        config_pipelining: PipeliningConfig,
+        pipeline_callback: PipelineOutputsProcessor
     ):
         """Constructs a ModelStageFactory object."""
 
@@ -169,7 +147,7 @@ class ModelStageFactory:
         self._config_model = config_model
         self._config_pipelining = config_pipelining
         self._batch_maths = batch_maths
-        self._loss_computer = loss_computer
+        self._pipeline_callback = pipeline_callback
 
     def _build_model_stage(self, stage: PipelineStageInfo) -> nn.Module:
         # create a model with no real memory occupied
@@ -218,21 +196,13 @@ class ModelStageFactory:
 
     def build_pipeline_and_modules(
         self
-    ) -> tuple[PipelineScheduleInfo | None, TrackedModules]:
+    ) -> tuple[PipelineScheduleInfo, TrackedModules]:
         """
         Constructs the execution schedule and the model container.
 
-        If pipeline parallelism is enabled, this orchestrates the creation of a
-        distributed pipeline schedule.
-
-        Otherwise, it simply builds a standalone model stage.
-
         Returns:
-            The pipeline schedule information (or None if no pipelining).
+            The pipeline schedule information.
             The `TrackedModules` instance wrapping the created model stage(s).
-
-        Raises:
-            ValueError: If pipelining configuration is missing but a pipeline is requested.
         """
 
         if self._config_model.checkpoint_only_trainable_parameters:
@@ -240,22 +210,12 @@ class ModelStageFactory:
         else:
             stateful_predicate = _stateful_predicate_always
 
-        if self._dist_context.mesh_params.has_pipeline_parallel:
-            if self._config_pipelining is None:
-                raise ValueError("Pipelining is enabled, but not configured")
-
-            loss_fn = self._loss_computer.compute_loss_mul_weight if self._loss_computer is not None else None
-
-            schedule, modules = build_schedule(
-                dist_context=self._dist_context,
-                n_microbatches=self._batch_maths.num_microbatches_pipelining,
-                schedule_config=self._config_pipelining.schedule,
-                model_provider=self._build_model_stage,
-                loss_fn=loss_fn
-            )
-
-            return schedule, TrackedModules(self._dist_context, modules, stateful_predicate)
-        else:
-            model = self._build_model_stage(PipelineStageInfo(num_stages=1, current_stage=0))
+        schedule, modules = build_schedule(
+            dist_context=self._dist_context,
+            n_microbatches=self._batch_maths.num_microbatches_pipelining,
+            schedule_config=self._config_pipelining.schedule,
+            model_provider=self._build_model_stage,
+            callback=self._pipeline_callback
+        )
 
-            return None, TrackedModules(self._dist_context, [model], stateful_predicate)
+        return schedule, TrackedModules(self._dist_context, modules, stateful_predicate)
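
Note that TrackedModules.__call__ is gone: with pipelining now always on, even a single-device run is expressed as a one-stage schedule, and modules are no longer invoked directly. The stages remain reachable through the modules property for inspection. A small sketch, assuming tracked_modules came from build_pipeline_and_modules():

    # count trainable parameters across the locally held pipeline stages
    n_trainable = sum(
        p.numel()
        for module in tracked_modules.modules
        for p in module.parameters()
        if p.requires_grad
    )
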
@@ -1,7 +1,10 @@
+import abc
+from typing import Generic, TypeVar
+
 import torch
 
 from d9d.internals.pipeline_state import PipelineStateHandler
-from d9d.loop.control import ComputeLossContext, TrainTask
+from d9d.loop.control import ComputeLossContext, InferenceTask, ProcessOutputsContext, TrainTask
 
 from .stepper import Stepper
 
@@ -9,7 +12,20 @@ STATE_LOSS = "__internal_loss"
 STATE_LOSS_WEIGHT = "__internal_loss_weight"
 
 
-class LossComputer:
+TOutput = TypeVar("TOutput")
+
+
+class PipelineOutputsProcessor(abc.ABC, Generic[TOutput]):
+    @abc.abstractmethod
+    def __call__(
+        self,
+        pipeline_outputs: dict[str, torch.Tensor],
+        microbatch_idx: int
+    ) -> TOutput:
+        ...
+
+
+class LossComputer(PipelineOutputsProcessor[torch.Tensor]):
     """
     Handles the computation of loss values and their integration into the pipeline state.
 
@@ -38,10 +54,10 @@ class LossComputer:
         self._task = task
         self._stepper = stepper
 
-    def compute_loss_mul_weight(
+    def __call__(
         self,
         pipeline_outputs: dict[str, torch.Tensor],
-        microbatch_idx: int | None
+        microbatch_idx: int
     ) -> torch.Tensor:
         """
         Computes the weighted loss for a specific sharded microbatch or the full microbatch.
@@ -61,12 +77,9 @@
             The calculated loss multiplied by its weight.
         """
 
-        if microbatch_idx is None:
-            state = self._state.global_state()
-        else:
-            state = self._state.sharded_state(
-                shard_id=microbatch_idx
-            )
+        state = self._state.sharded_state(
+            shard_id=microbatch_idx
+        )
 
         computation = self._task.compute_loss(ComputeLossContext(
             pipeline_results=pipeline_outputs,
@@ -84,3 +97,53 @@
         state[STATE_LOSS_WEIGHT] = loss_weight[None]
 
         return loss * loss_weight
+
+
+class InferenceProcessor(PipelineOutputsProcessor[None]):
+    """
+    Handles the processing of model outputs during inference or evaluation.
+
+    This component retrieves the appropriate state context
+    and delegates the output processing logic to the user-defined inference task.
+    """
+
+    def __init__(
+        self,
+        state: PipelineStateHandler,
+        task: InferenceTask
+    ):
+        """
+        Constructs a new ModelOutputsProcessor.
+
+        Args:
+            state: Handler for managing global and sharded pipeline states.
+            task: The user-defined inference task containing processing logic.
+        """
+
+        self._state = state
+        self._task = task
+
+    def __call__(
+        self,
+        pipeline_outputs: dict[str, torch.Tensor],
+        microbatch_idx: int
+    ) -> None:
+        """
+        Processes model outputs for a specific microbatch or full batch.
+
+        This method retrieves the relevant state (scoped by microbatch index if provided)
+        and invokes the task's output processing logic.
+
+        Args:
+            pipeline_outputs: Dictionary containing model output tensors.
+            microbatch_idx: Index of the current microbatch, or None if not using microbatching.
+        """
+
+        state = self._state.sharded_state(
+            shard_id=microbatch_idx
+        )
+
+        self._task.process_outputs(ProcessOutputsContext(
+            pipeline_results=pipeline_outputs,
+            state=state
+        ))
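
PipelineOutputsProcessor is the new extension point the schedule invokes once per microbatch. A hypothetical subclass, sketched only from the abstract __call__ signature shown above; the "logits" output key and the class itself are illustrative assumptions, not part of the package, and the import path of the base class is not visible in this diff:

    import torch

    class LogitNormRecorder(PipelineOutputsProcessor[torch.Tensor]):
        """Records the norm of the 'logits' output for every microbatch."""

        def __init__(self) -> None:
            self.norms: list[torch.Tensor] = []

        def __call__(
            self,
            pipeline_outputs: dict[str, torch.Tensor],
            microbatch_idx: int
        ) -> torch.Tensor:
            # the schedule passes the stage outputs keyed by name; "logits" is assumed here
            norm = pipeline_outputs["logits"].norm()
            self.norms.append(norm)
            return norm
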
@@ -5,12 +5,17 @@ import torch
 from d9d.core.dist_context import DistributedContext
 from d9d.core.types import PyTree
 from d9d.internals.pipeline_state import PipelineStateHandler
-from d9d.loop.control import BuildForwardInputsContext, BuildForwardInputsResult, TrainTask, UpdateMetricsContext
+from d9d.loop.control import (
+    BaseTask,
+    BuildForwardInputsContext,
+    InferenceTask,
+    TrainTask,
+    UpdateMetricsContext,
+)
 from d9d.metric.impl import ComposeMetric
 from d9d.pipelining.factory.factory import PipelineScheduleInfo
 
-from .loss_computer import STATE_LOSS, STATE_LOSS_WEIGHT, LossComputer
-from .model_stage_factory import TrackedModules
+from .pipeline_result_processing import STATE_LOSS, STATE_LOSS_WEIGHT
 
 
 @dataclasses.dataclass(kw_only=True)
@@ -27,12 +32,34 @@ class ForwardResult:
     loss_weight: torch.Tensor
 
 
+def _run_pipeline(
+    task: BaseTask,
+    pipeline: PipelineScheduleInfo,
+    pipeline_state: PipelineStateHandler,
+    batch: PyTree
+):
+    model_inputs = task.build_forward_inputs(
+        BuildForwardInputsContext(
+            batch=batch,
+            state=pipeline_state.global_state()
+        )
+    )
+    pipeline.schedule.configure_buffers(
+        inputs=model_inputs.inputs,
+        kwargs=model_inputs.kwargs,
+        sharding_spec=model_inputs.pipeline_sharding_spec
+    )
+    pipeline.schedule.step(
+        inputs=model_inputs.inputs,
+        kwargs=model_inputs.kwargs
+    )
+
+
 class TrainTaskOperator:
     """
     Orchestrates the execution of the forward and backward passes for a specific training task.
 
-    This class abstracts the difference between standard execution
-    and pipeline-parallel execution. It manages input construction, schedule execution,
+    It manages input construction, schedule execution,
     loss computation, and metric updates within the lifecycle of a single step.
     """
 
@@ -40,9 +67,7 @@ class TrainTaskOperator:
         self,
         dist_context: DistributedContext,
         task: TrainTask,
-        pp_schedule: PipelineScheduleInfo | None,
-        tracked_modules: TrackedModules,
-        loss_computer: LossComputer,
+        pipeline: PipelineScheduleInfo,
         pipeline_state: PipelineStateHandler,
         metrics: ComposeMetric
     ):
@@ -52,48 +77,17 @@ class TrainTaskOperator:
         Args:
             dist_context: The distributed context.
             task: The user-defined training task logic.
-            pp_schedule: Information about the pipeline schedule.
-            tracked_modules: The model modules being trained.
-            loss_computer: Component responsible for calculating loss from outputs.
+            pipeline: Information about the pipeline schedule.
            pipeline_state: Handler for transient state storage during the step.
             metrics: Metric collection to update after the pass.
         """
 
         self._dist_context = dist_context
         self._task = task
-        self._pp_schedule = pp_schedule
-        self._tracked_modules = tracked_modules
-        self._loss_computer = loss_computer
+        self._pipeline = pipeline
         self._pipeline_state = pipeline_state
         self._metrics = metrics
 
-    def _forward_backward_pipelining(self, model_inputs: BuildForwardInputsResult):
-        if self._pp_schedule is None:
-            raise ValueError("Cannot run pipelined pass if pipelining is disabled")
-
-        self._pp_schedule.schedule.configure_buffers(
-            inputs=model_inputs.inputs,
-            kwargs=model_inputs.kwargs,
-            sharding_spec=model_inputs.pipeline_sharding_spec
-        )
-        self._pp_schedule.schedule.step(
-            inputs=model_inputs.inputs,
-            kwargs=model_inputs.kwargs
-        )
-
-    def _forward_backward_regular(self, model_inputs: BuildForwardInputsResult):
-        pipeline_outputs = self._tracked_modules(
-            **model_inputs.inputs,
-            **model_inputs.kwargs
-        )
-        loss = self._loss_computer.compute_loss_mul_weight(
-            pipeline_outputs=pipeline_outputs,
-            microbatch_idx=None
-        )
-        # free to avoid bwd peaking memory
-        del pipeline_outputs
-        loss.backward()
-
     def forward_backward(self, batch: PyTree) -> ForwardResult | None:
         """
         Executes the forward and backward passes for a single batch.
@@ -117,27 +111,17 @@
 
         try:
             # Do forward and backward pass
-            model_inputs = self._task.build_forward_inputs(
-                BuildForwardInputsContext(
-                    batch=batch,
-                    state=self._pipeline_state.global_state()
-                )
+            _run_pipeline(
+                pipeline_state=self._pipeline_state,
+                task=self._task,
+                pipeline=self._pipeline,
+                batch=batch
             )
 
-            if self._dist_context.mesh_params.has_pipeline_parallel:
-                self._forward_backward_pipelining(model_inputs)
-            else:
-                self._forward_backward_regular(model_inputs)
-
             # Update metrics if possible
 
             pipeline_state = self._pipeline_state.global_state()
-
-            if (
-                self._dist_context.mesh_params.has_pipeline_parallel and
-                self._pp_schedule is not None and
-                not self._pp_schedule.has_last_stage
-            ):
+            if not self._pipeline.has_last_stage:
                 return None
 
             self._task.update_metrics(UpdateMetricsContext(
@@ -150,3 +134,59 @@
             )
         finally:
             self._pipeline_state.reset()
+
+
+class InferenceTaskOperator:
+    """
+    Orchestrates the execution of the forward pass for a specific inference task.
+
+    It manages input
+    construction, schedule execution, and state lifecycle management.
+    """
+
+    def __init__(
+        self,
+        dist_context: DistributedContext,
+        task: InferenceTask,
+        pipeline: PipelineScheduleInfo,
+        pipeline_state: PipelineStateHandler
+    ):
+        """
+        Constructs the InferenceTaskOperator.
+
+        Args:
+            dist_context: The distributed context.
+            task: The user-defined inference task logic.
+            pipeline: Information about the pipeline schedule.
+            pipeline_state: Handler for transient state storage during the step.
+        """
+
+        self._dist_context = dist_context
+        self._task = task
+        self._pipeline = pipeline
+        self._pipeline_state = pipeline_state
+
+    def forward(self, batch: PyTree) -> None:
+        """
+        Executes the forward pass for a single batch.
+
+        This method handles:
+
+        1. Context preparation and input building via the `InferenceTask`.
+        2. Execution via Pipeline Parallel schedule.
+        3. Reliable cleanup of the pipeline state.
+
+        Args:
+            batch: The input batch data.
+        """
+
+        try:
+            # Do forward pass
+            _run_pipeline(
+                pipeline_state=self._pipeline_state,
+                task=self._task,
+                pipeline=self._pipeline,
+                batch=batch
+            )
+        finally:
+            self._pipeline_state.reset()
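
Both operators now drive the same _run_pipeline helper, and the only rank-dependent behaviour left in forward_backward is the has_last_stage check. A hedged sketch of a training step loop, assuming train_operator and dataloader were built elsewhere (their construction is outside this diff):

    for batch in dataloader:
        result = train_operator.forward_backward(batch)
        if result is None:
            # this rank does not hold the last pipeline stage, so it has no
            # loss or metrics to report for this step
            continue
        # ranks holding the last stage receive a ForwardResult (e.g. loss_weight)
        print(result.loss_weight)
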
d9d/loop/config/config.py CHANGED
@@ -190,7 +190,7 @@ class TrainerConfig(BaseModel):
     batching: BatchingConfig
     data_loading: DataLoadingConfig
     logging: JobLoggerConfig
-    pipelining: PipeliningConfig | None
+    pipelining: PipeliningConfig
     model_stage_factory: ModelStageFactoryConfig
     determinism: DeterminismConfig
     gc: GarbageCollectionConfig
d9d/loop/control/task.py CHANGED
@@ -252,11 +252,11 @@ class ProcessOutputsContext:
     Context data provided to process outputs during inference.
 
     Attributes:
-        outputs: The outputs returned by the model's forward pass.
+        pipeline_results: The outputs returned by the model's forward pass.
         state: The current state of the pipeline.
     """
 
-    outputs: dict[str, torch.Tensor]
+    pipeline_results: dict[str, torch.Tensor]
    state: "PipelineState"
 
 
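Implementations of process_outputs must switch from ctx.outputs to ctx.pipeline_results when upgrading. A hypothetical task sketch; the "logits" key and the "predictions" state entry are illustrative assumptions (dict-style state access mirrors what LossComputer does above), and only the renamed attribute comes from this diff:

    class MyInferenceTask(InferenceTask):
        def process_outputs(self, ctx: ProcessOutputsContext) -> None:
            logits = ctx.pipeline_results["logits"]  # was ctx.outputs in 0.1.1
            ctx.state["predictions"] = logits.argmax(dim=-1)
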
d9d/loop/run/__init__.py CHANGED
@@ -1,6 +1,9 @@
+from .inference import Inference, InferenceConfigurator
 from .train import Trainer, TrainingConfigurator
 
 __all__ = [
+    "Inference",
+    "InferenceConfigurator",
     "Trainer",
     "TrainingConfigurator"
 ]
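
The run package now exposes the inference entry points alongside the training ones, so downstream code can import everything at the package level:

    # new in 0.2.0 per the __all__ above
    from d9d.loop.run import Inference, InferenceConfigurator, Trainer, TrainingConfigurator
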