d9d 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,256 @@
+
+import torch
+from tqdm import tqdm
+
+from d9d.core.dist_context import DeviceMeshParameters
+from d9d.internals.determinism import set_seeds
+from d9d.internals.pipeline_state import PipelineStateHandler
+from d9d.loop.component import (
+    BatchMaths,
+    DataLoaderFactory,
+    InferenceProcessor,
+    InferenceTaskOperator,
+    JobProfiler,
+    ManualGarbageCollector,
+    ModelStageFactory,
+    StateCheckpointer,
+    Stepper,
+    TimeoutManager,
+)
+from d9d.loop.config import InferenceConfig, PipeliningConfig
+from d9d.loop.control import (
+    DatasetProvider,
+    FinalizeContext,
+    InferenceTaskProvider,
+    InferenceTaskProviderContext,
+    ModelProvider,
+)
+from d9d.loop.state import InferenceJobState
+from d9d.pipelining.factory import PipelineScheduleInferenceConfig
+
+
+class InferenceConfigurator:
+    """
+    Orchestrates the assembly of the distributed inference environment.
+
+    This class binds the infrastructure configuration (DeviceMesh), the inference
+    parameters, and the user-defined logic (Providers) to create a fully
+    initialized state object capable of running the inference loop.
+    """
+
+    def __init__(
+        self,
+        mesh: DeviceMeshParameters,
+        parameters: InferenceConfig,
+        task_provider: InferenceTaskProvider,
+        model_provider: ModelProvider,
+        data_provider: DatasetProvider
+    ):
+        """
+        Constructs a configurator capable of building the full inference state.
+
+        Args:
+            mesh: Definition of the distributed device mesh topology.
+            parameters: The global configuration object for inference.
+            task_provider: Factory for creating the inference task logic.
+            model_provider: Factory for defining and creating model stages.
+            data_provider: Factory for providing inference datasets.
+        """
+
+        self._mesh = mesh
+        self._parameters = parameters
+        self._task_provider = task_provider
+        self._model_provider = model_provider
+        self._data_provider = data_provider
+
+    def _build_new_state(self) -> InferenceJobState:
+        dist_context = self._mesh.build()
+
+        pipelining_config = PipeliningConfig(
+            schedule=PipelineScheduleInferenceConfig()
+        )
+
+        set_seeds(dist_context, seed=self._parameters.determinism.base_seed)
+
+        task = self._task_provider(InferenceTaskProviderContext(
+            dist_context=dist_context
+        ))
+
+        batch_maths = BatchMaths(
+            dist_context=dist_context,
+            config_batching=self._parameters.batching,
+            config_pipelining=pipelining_config
+        )
+
+        data_loader_factory = DataLoaderFactory(
+            dist_context=dist_context,
+            provider=self._data_provider,
+            config_data_loading=self._parameters.data_loading,
+            batch_maths=batch_maths
+        )
+        data_loader_infer = data_loader_factory.build_dataloader_for_infer_job()
+
+        stepper = Stepper(
+            initial_step=1,
+            total_steps=len(data_loader_infer)
+        )
+
+        pipeline_state_handler = PipelineStateHandler(
+            sharding_spec={},
+            num_shards=batch_maths.num_microbatches_pipelining
+        )
+
+        processor = InferenceProcessor(
+            state=pipeline_state_handler,
+            task=task
+        )
+
+        schedule, modules = ModelStageFactory(
+            model_provider=self._model_provider,
+            dist_context=dist_context,
+            config_model=self._parameters.model_stage_factory,
+            config_pipelining=pipelining_config,
+            batch_maths=batch_maths,
+            pipeline_callback=processor
+        ).build_pipeline_and_modules()
+
+        task_operator = InferenceTaskOperator(
+            dist_context=dist_context,
+            task=task,
+            pipeline=schedule,
+            pipeline_state=pipeline_state_handler
+        )
+
+        gc = ManualGarbageCollector(
+            dist_ctx=dist_context,
+            config=self._parameters.gc,
+            step=stepper
+        )
+
+        checkpointer = StateCheckpointer(
+            dist_context=dist_context,
+            stepper=stepper,
+            config=self._parameters.checkpointing,
+            gc=gc,
+            run_name=None
+        )
+
+        profiler = JobProfiler(
+            dist_context=dist_context,
+            stepper=stepper,
+            config=self._parameters.profiling
+        )
+
+        timeout_manager = TimeoutManager(
+            dist_context=dist_context,
+            config=self._parameters.timeout
+        )
+
+        return InferenceJobState(
+            dist_context=dist_context,
+            data_loader=data_loader_infer,
+            stepper=stepper,
+            tracked_modules=modules,
+            garbage_collector=gc,
+            batch_maths=batch_maths,
+            checkpointer=checkpointer,
+            task=task,
+            profiler=profiler,
+            timeout_manager=timeout_manager,
+            task_operator=task_operator
+        )
+
+    def configure(self) -> "Inference":
+        """
+        Instantiates all inference components and returns a configured Inference engine.
+
+        This method triggers the creation of the distributed context, sets seeds,
+        builds the model, data loaders, and attaches all auxiliary components.
+
+        Returns:
+            Inference: A ready-to-use inference engine instance encapsulating the job state.
+        """
+
+        state = self._build_new_state()
+
+        return Inference(state)
+
+
+class Inference:
+    """
+    The main execution engine for running a distributed inference job.
+
+    This class manages the inference loop, lifecycle events, distributed synchronization,
+    and periodic side-effects (profiling, checkpointing). It ensures the model is in
+    evaluation mode and runs within a `torch.inference_mode` context.
+    """
+
+    def __init__(self, state: InferenceJobState):
+        """
+        Constructs an Inference engine from a pre-built job state.
+
+        Args:
+            state: The encapsulated state object containing all initialized components.
+        """
+
+        self._state = state
+
+    def _enable_eval_mode(self):
+        for module in self._state.tracked_modules.modules:
+            module.eval()
+
+    def infer(self):
+        """
+        Executes the full inference workflow.
+
+        This method:
+
+        1. Waits for world synchronization.
+        2. Loads the latest checkpoint if available.
+        3. Iterates through the data loader.
+        4. Executes the pipeline forward pass for every batch.
+        5. Handles periodic garbage collection and profiling.
+        6. Finalizes the task upon completion.
+        """
+
+        with torch.inference_mode():
+            self._enable_eval_mode()
+
+            self._state.dist_context.logger.info("Waiting for the world to start job")
+            self._state.dist_context.wait_world()
+            self._state.dist_context.logger.info("Trying to load last checkpoint before doing anything else")
+            self._state.checkpointer.load_last_checkpoint(self._state)
+
+            if self._state.stepper.current_step >= self._state.stepper.total_steps:
+                self._state.dist_context.logger.info("Already ran, will do nothing")
+                return
+
+            self._state.dist_context.wait_world()
+
+            with (
+                tqdm(
+                    desc="Inference",
+                    total=self._state.stepper.total_steps,
+                    disable=not self._state.dist_context.is_local_main_process,
+                    initial=self._state.stepper.current_step
+                ) as bar,
+                self._state.garbage_collector as gc,
+                self._state.profiler.open() as profiler
+            ):
+                self._state.timeout_manager.step()
+
+                for batch_group in self._state.data_loader:
+                    for batch in batch_group:
+                        self._state.task_operator.forward(batch)

+                    gc.collect_periodic()
+                    self._state.stepper.step()
+                    bar.update()

+                    # checkpoint at the end of the step
+                    self._state.checkpointer.checkpoint_if_needed(self._state)

+                    if profiler:
+                        profiler.step()

+            self._state.task.finalize(FinalizeContext())
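
For orientation, here is a minimal sketch of how the new entry point above appears intended to be wired up. It is not taken from the package: the `run_inference` helper and the bare provider arguments are assumptions based solely on the signatures shown in this file, and the import of `InferenceConfigurator` is omitted because this diff does not name the module's path.

    from d9d.core.dist_context import DeviceMeshParameters
    from d9d.loop.config import InferenceConfig

    def run_inference(
        mesh: DeviceMeshParameters,
        parameters: InferenceConfig,
        task_provider,     # an InferenceTaskProvider implementation
        model_provider,    # a ModelProvider implementation
        data_provider,     # a DatasetProvider implementation
    ):
        # Bind topology, config and user-defined providers exactly as the constructor expects.
        configurator = InferenceConfigurator(
            mesh=mesh,
            parameters=parameters,
            task_provider=task_provider,
            model_provider=model_provider,
            data_provider=data_provider,
        )
        # configure() builds the distributed context, model stages and data loader;
        # infer() then runs the checkpoint-aware loop and finalizes the task.
        configurator.configure().infer()
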
d9d/loop/run/train.py CHANGED
@@ -121,7 +121,7 @@ class TrainingConfigurator:
             config_model=self._parameters.model_stage_factory,
             config_pipelining=self._parameters.pipelining,
             batch_maths=batch_maths,
-            loss_computer=loss_computer
+            pipeline_callback=loss_computer
         ).build_pipeline_and_modules()

         metrics = ComposeMetric(task.create_metrics(CreateMetricsContext()).metrics)
@@ -130,9 +130,7 @@ class TrainingConfigurator:
         task_operator = TrainTaskOperator(
             dist_context=dist_context,
             task=task,
-            pp_schedule=schedule,
-            tracked_modules=modules,
-            loss_computer=loss_computer,
+            pipeline=schedule,
             pipeline_state=pipeline_state_handler,
             metrics=metrics
         )
d9d/loop/state.py CHANGED
@@ -10,6 +10,7 @@ from d9d.loop.component import (
     BatchMaths,
     GradientClipper,
     GradientManager,
+    InferenceTaskOperator,
     JobLogger,
     JobProfiler,
     ManualGarbageCollector,
@@ -123,14 +124,16 @@ class TrainJobState(JobState):


 @dataclasses.dataclass(kw_only=True)
-class InferJobState(JobState):
+class InferenceJobState(JobState):
     """
     Container for the state of an inference job.

     Attributes:
         task: The specific inference task logic definition.
+        task_operator: Executor for running forward and backward passes.
     """
     task: InferenceTask
+    task_operator: InferenceTaskOperator

     def state_dict(self) -> dict[str, Any]:
         return {
@@ -9,9 +9,12 @@ from .module import (
 )
 from .schedule import PipelineSchedule
 from .sharding import PipelineShardingSpec
+from .types import PipelineLossFn, PipelineResultFn

 __all__ = [
     "ModuleSupportsPipelining",
+    "PipelineLossFn",
+    "PipelineResultFn",
     "PipelineSchedule",
     "PipelineShardingSpec",
     "PipelineStageInfo",
@@ -0,0 +1,28 @@
+from collections.abc import Callable
+from typing import Any
+
+import torch
+
+PipelineResultFn = Callable[[dict[str, torch.Tensor], int], Any]
+"""
+Callback function type for handling results from a final pipeline stage.
+
+Args:
+    outputs: A dictionary mapping output names to tensors produced by the stage.
+    microbatch_idx: The index of the current micro-batch being processed.
+
+Returns:
+    Anything - not used.
+"""
+
+PipelineLossFn = Callable[[dict[str, torch.Tensor], int], torch.Tensor]
+"""
+Callback function type for calculating loss in the final pipeline stage.
+
+Args:
+    outputs: A dictionary mapping output names to tensors produced by the model.
+    microbatch_idx: The index of the current micro-batch being processed.
+
+Returns:
+    The computed loss tensor (scalar).
+"""
@@ -1,19 +1,19 @@
 import dataclasses
 from collections.abc import Callable

-import torch
 from torch import nn

 from ...core.dist_context import REGULAR_DOMAIN, DistributedContext
-from ..api import PipelineSchedule, PipelineStageInfo
+from ..api import PipelineLossFn, PipelineResultFn, PipelineSchedule, PipelineStageInfo
 from ..infra.schedule.component.program import (
     build_stage_to_host_rank_topology,
     invert_stage_to_host_rank_topology,
 )
-from ..infra.schedule.component.runtime import PipelineScheduleExecutor
+from ..infra.schedule.component.runtime import OfflinePipelineExecutor, PipelineScheduleExecutor
 from ..infra.stage import PipelineStage
 from .config import (
     AnyPipelineScheduleConfig,
+    PipelineScheduleInferenceConfig,
 )
 from .registry import PIPELINE_PROGRAM_REGISTRY

@@ -27,38 +27,31 @@ class PipelineScheduleInfo:
     has_last_stage: bool


-def build_schedule(
-    dist_context: DistributedContext,
-    n_microbatches: int,
+def _build_schedule_local(
     schedule_config: AnyPipelineScheduleConfig,
     model_provider: Callable[[PipelineStageInfo], nn.Module],
-    loss_fn: Callable[[dict[str, torch.Tensor], int], torch.Tensor] | None,
+    callback: PipelineLossFn | PipelineResultFn
 ) -> tuple[PipelineScheduleInfo, list[nn.Module]]:
-    """
-    Constructs the pipeline schedule and instantiates model stages.
+    stage_info = PipelineStageInfo(num_stages=1, current_stage=0)

-    This function coordinates the creation of the distributed pipeline. It:
-    1. Selects the appropriate `PipelineProgramBuilder` based on the config.
-    2. Calculates the global stage topology mapping stages to ranks.
-    3. Instantiates the local model stages for the current rank using `model_provider`.
-    4. Wraps models in `PipelineStage` containers.
-    5. Generates the execution program (action list).
-    6. Builds the runtime executor.
+    model = model_provider(stage_info)
+    has_backward = not isinstance(schedule_config, PipelineScheduleInferenceConfig)
+    scheduler = OfflinePipelineExecutor(model=model, callback=callback, do_backward=has_backward)

-    Args:
-        dist_context: The distributed context.
-        n_microbatches: Number of microbatches per global step.
-        schedule_config: Configuration object determining the schedule strategy.
-        model_provider: A factory function that accepts stage info and returns an `nn.Module`
-            for that specific stage.
-        loss_fn: Optional loss function. Required if training (backward pass needed).
+    return PipelineScheduleInfo(
+        schedule=scheduler,
+        has_first_stage=True,
+        has_last_stage=True
+    ), [model]

-    Returns:
-        A tuple containing:
-        1. `PipelineScheduleInfo`: The executable schedule and metadata.
-        2. `list[nn.Module]`: The local PyTorch modules created for this rank.
-    """

+def _build_schedule_distributed(
+    dist_context: DistributedContext,
+    n_microbatches: int,
+    schedule_config: AnyPipelineScheduleConfig,
+    model_provider: Callable[[PipelineStageInfo], nn.Module],
+    callback: PipelineLossFn | PipelineResultFn,
+) -> tuple[PipelineScheduleInfo, list[nn.Module]]:
     program_builder = PIPELINE_PROGRAM_REGISTRY.program_for(schedule_config)
     mesh = dist_context.mesh_for(REGULAR_DOMAIN)["pp"]

@@ -103,7 +96,7 @@ def build_schedule(
         dist_context=dist_context,
         stages=stages,
         num_microbatches=n_microbatches,
-        loss_fn=loss_fn,
+        callback=callback,
         program=program
     )

@@ -112,3 +105,49 @@ def build_schedule(
         has_first_stage=has_first_stage,
         has_last_stage=has_last_stage
     ), modules
+
+
+def build_schedule(
+    dist_context: DistributedContext,
+    n_microbatches: int,
+    schedule_config: AnyPipelineScheduleConfig,
+    model_provider: Callable[[PipelineStageInfo], nn.Module],
+    callback: PipelineLossFn | PipelineResultFn,
+) -> tuple[PipelineScheduleInfo, list[nn.Module]]:
+    """
+    Constructs the pipeline schedule and instantiates model stages.
+
+    This function coordinates the creation of the pipeline. If the context is
+    distributed, it builds a parallel schedule (`PipelineScheduleExecutor`) by
+    calculating topology and creating stages for the current rank. If the
+    context is local, it builds an offline schedule (`OfflinePipelineExecutor`)
+    for direct execution.
+
+    Args:
+        dist_context: The distributed context.
+        n_microbatches: Number of microbatches per global step.
+        schedule_config: Configuration object determining the schedule strategy.
+        model_provider: A factory function that accepts stage info and returns an `nn.Module`
+            for that specific stage.
+        callback: Callback either computing loss function (if training) or just processing pipeline outputs
+            (if not training).
+
+    Returns:
+        A tuple containing the schedule info (executor and metadata) and a list
+        of local PyTorch modules created for this rank.
+    """
+
+    if dist_context.mesh_params.is_distributed:
+        return _build_schedule_distributed(
+            dist_context=dist_context,
+            n_microbatches=n_microbatches,
+            schedule_config=schedule_config,
+            model_provider=model_provider,
+            callback=callback
+        )
+    else:
+        return _build_schedule_local(
+            schedule_config=schedule_config,
+            model_provider=model_provider,
+            callback=callback
+        )
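
A hedged call sketch of the refactored `build_schedule`: the single-layer model, the no-op callback, and the pre-built `dist_context` are placeholders, and this diff does not show the function's own import path.

    from torch import nn
    from d9d.pipelining.api import PipelineStageInfo
    from d9d.pipelining.factory import PipelineScheduleInferenceConfig

    def model_provider(stage: PipelineStageInfo) -> nn.Module:
        # On a non-distributed mesh the local path calls this once with
        # num_stages=1, current_stage=0 and wraps the module in OfflinePipelineExecutor.
        return nn.Linear(16, 4)

    def on_result(outputs, microbatch_idx):
        # PipelineResultFn-style callback; an inference schedule config implies
        # do_backward=False, so no loss handler is needed.
        pass

    info, modules = build_schedule(
        dist_context=dist_context,  # assumed to be built elsewhere, e.g. via DeviceMeshParameters.build()
        n_microbatches=4,
        schedule_config=PipelineScheduleInferenceConfig(),
        model_provider=model_provider,
        callback=on_result,
    )
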
@@ -14,6 +14,7 @@ from .action import (
     ForwardSendAction,
 )
 from .executor import PipelineScheduleExecutor
+from .offline import OfflinePipelineExecutor

 __all__ = [
     "ActionBase",
@@ -25,5 +26,6 @@ __all__ = [
     "ForwardComputeAction",
     "ForwardReceiveAction",
     "ForwardSendAction",
-    "PipelineScheduleExecutor",
+    "OfflinePipelineExecutor",
+    "PipelineScheduleExecutor"
 ]
@@ -7,8 +7,8 @@ import torch

 from d9d.pipelining.infra.stage import PipelineStage

+from .callback import PipelineLossHandler, PipelineResultHandler
 from .communications import PipelineCommunicationHandler
-from .loss import PipelineLossHandler


 @dataclasses.dataclass(kw_only=True, slots=True)
@@ -21,7 +21,7 @@ class ActionContext:
         pipeline_kwargs_microbatches: The global keyword arguments sharded by microbatch.
         stages: A mapping of stage indices to their active PipelineStage instances.
         communications: The handler for P2P communications.
-        loss: The handler for loss computation, or None if not available.
+        callback: The handler for either loss computation or result processing.
     """

     pipeline_inputs_microbatches: tuple[dict[str, torch.Tensor], ...]
@@ -29,7 +29,7 @@ class ActionContext:

     stages: dict[int, PipelineStage]
     communications: PipelineCommunicationHandler
-    loss: PipelineLossHandler | None
+    callback: PipelineLossHandler | PipelineResultHandler


 class ActionWorkType(StrEnum):
@@ -208,7 +208,6 @@ class ForwardComputeAction(ActionBase):
     microbatch_idx: int

     def apply(self, ctx: ActionContext):
-        # todo check unsharded
         stage = ctx.stages[self.stage_idx]

         if not stage.info.is_current_stage_first and self.stage_idx - 1 not in ctx.stages:
@@ -221,8 +220,8 @@ class ForwardComputeAction(ActionBase):
         )
         result = stage.get_local_fwd_output(self.microbatch_idx)

-        if stage.info.is_current_stage_last and ctx.loss is not None:
-            ctx.loss.compute_loss(result, self.microbatch_idx)
+        if stage.info.is_current_stage_last:
+            ctx.callback.trigger(result, self.microbatch_idx)

         if not stage.info.is_current_stage_last and self.stage_idx + 1 in ctx.stages:
             ctx.stages[self.stage_idx + 1].set_local_fwd_input(
@@ -260,14 +259,13 @@ class BackwardFullInputComputeAction(ActionBase):
     full_backward: bool

     def apply(self, ctx: ActionContext):
-        # todo unshard
         stage = ctx.stages[self.stage_idx]

         if not stage.info.is_current_stage_last and self.stage_idx + 1 not in ctx.stages:
             ctx.communications.wait_bwd_recv(self.stage_idx, self.microbatch_idx)

-        if stage.info.is_current_stage_last and ctx.loss is not None:
-            loss = ctx.loss.acquire_loss(self.microbatch_idx)
+        if stage.info.is_current_stage_last and isinstance(ctx.callback, PipelineLossHandler):
+            loss = ctx.callback.acquire_loss(self.microbatch_idx)
         else:
             loss = None

@@ -310,7 +308,6 @@ class BackwardWeightComputeAction(ActionBase):
     microbatch_idx: int

     def apply(self, ctx: ActionContext):
-        # todo unshard
         stage = ctx.stages[self.stage_idx]

         stage.backward_weight_one_chunk(
@@ -1,14 +1,41 @@
-from collections.abc import Callable
-
 import torch

-LossFn = Callable[[dict[str, torch.Tensor], int], torch.Tensor]
+from d9d.pipelining.api import PipelineLossFn, PipelineResultFn
+
+
+class PipelineResultHandler:
+    """
+    Wraps a callback function to handle results from pipeline execution.
+    """
+
+    def __init__(self, callback_fn: PipelineResultFn):
+        """
+        Constructs PipelineResultHandler object.
+
+        Args:
+            callback_fn: The function called with results.
+        """
+
+        self._callback_fn = callback_fn
+
+    def trigger(self, forward_result: dict[str, torch.Tensor], microbatch_index: int):
+        """
+        Invokes the underlying callback with the provided results.
+
+        Args:
+            forward_result: Dictionary of output tensors from the pipeline.
+            microbatch_index: The index of the current micro-batch.
+        """
+
+        self._callback_fn(forward_result, microbatch_index)


 class PipelineLossHandler:
-    """Manages loss computation and state caching across forward and backward passes."""
+    """
+    Manages loss computation and state caching across forward and backward passes.
+    """

-    def __init__(self, loss_fn: LossFn):
+    def __init__(self, loss_fn: PipelineLossFn):
         """
         Constructs the loss handler.

@@ -19,21 +46,17 @@ class PipelineLossHandler:
         self._loss_fn = loss_fn
         self._cached_values: dict[int, torch.Tensor] = {}

-    def compute_loss(self, forward_result: dict[str, torch.Tensor], microbatch_index: int) -> torch.Tensor:
+    def trigger(self, forward_result: dict[str, torch.Tensor], microbatch_index: int):
         """
         Computes loss for a given microbatch result and caches it.

         Args:
             forward_result: The output from the last stage of the model.
             microbatch_index: The index of the microbatch being processed.
-
-        Returns:
-            The computed loss tensor.
         """

         result = self._loss_fn(forward_result, microbatch_index)
         self._cached_values[microbatch_index] = result
-        return result

     def acquire_loss(self, microbatch_index: int) -> torch.Tensor:
         """