d9d 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from collections.abc import Iterator, Mapping, Sequence
|
|
3
|
+
from typing import Any, cast
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
from torch.autograd.graph import Node
|
|
8
|
+
|
|
9
|
+
from .splitgrad import (
|
|
10
|
+
ParamGroup,
|
|
11
|
+
stage_backward_full,
|
|
12
|
+
stage_backward_input,
|
|
13
|
+
stage_backward_weight,
|
|
14
|
+
)
|
|
15
|
+
from .struct_helper import DictFlattener
|
|
16
|
+
|
|
17
|
+
# TODO/NOTICE: We WILL NOT disable FSDP's resharding for microbatches since it will modify
|
|
18
|
+
# TODO/NOTICE: its behavior in an unexpected way. Perhaps we need a better FSDP resharding policy handler?
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclasses.dataclass(slots=True)
|
|
22
|
+
class ForwardCache:
|
|
23
|
+
"""
|
|
24
|
+
Stores the inputs and outputs of a forward pass to be used later in the backward pass.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
inputs: dict[str, torch.Tensor]
|
|
28
|
+
outputs: dict[str, torch.Tensor]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ForwardComputeHandler:
    """
    Handles the execution of the forward pass for a pipeline stage module.

    Maintains a cache of inputs and outputs indexed by microbatch ID.
    """

    def __init__(
        self,
        stage_index: int,
        module: nn.Module
    ):
        """
        Constructs a ForwardComputeHandler object.

        Args:
            stage_index: Logical index of the stage.
            module: The PyTorch module representing this stage computation.
        """

        self._stage_idx = stage_index
        self._module = module

        self._cache: dict[int, ForwardCache] = {}

    def run(
        self,
        microbatch_index: int,
        inputs: dict[str, torch.Tensor],
        kwargs: dict[str, Any]
    ) -> dict[str, torch.Tensor]:
        """
        Executes the module's forward pass and caches its inputs and outputs.

        The module is called as ``module(**inputs, **kwargs)``. Entries of its output mapping
        whose value is ``None`` are dropped before caching.

        Args:
            microbatch_index: Identifier for the current microbatch.
            inputs: Dictionary of input tensors.
            kwargs: Additional keyword arguments for the module.

        Returns:
            The module outputs with ``None`` values removed (the same dict that is cached).

        Raises:
            RuntimeError: If the forward pass implementation fails (original error is chained).
            ValueError: If the module does not return a mapping.
        """

        # Compute forward
        try:
            raw_output = self._module(**inputs, **kwargs)
        except Exception as e:
            raise RuntimeError(f"S{self._stage_idx}B{microbatch_index} failed to run forward") from e

        if not isinstance(raw_output, Mapping):
            raise ValueError("Currently, pipelined models should output dict[str, torch.Tensor | None]")

        # Optional outputs that were not produced are simply not part of the stage's outputs.
        output = {k: v for k, v in raw_output.items() if v is not None}

        self._cache[microbatch_index] = ForwardCache(
            inputs=inputs,
            outputs=output
        )

        return output

    def get_outputs(self, microbatch_index: int) -> dict[str, torch.Tensor]:
        """
        Retrieves cached outputs for a specific microbatch without removing them.

        Args:
            microbatch_index: Identifier for the microbatch.

        Returns:
            Dictionary of output tensors.

        Raises:
            KeyError: If no forward pass is cached for this microbatch.
        """

        return self._cache[microbatch_index].outputs

    def pop_inputs_outputs(
        self, microbatch_index: int
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """
        Retrieves and removes the cached inputs and outputs for a specific microbatch.

        Typically called when initiating the backward pass.

        Args:
            microbatch_index: Identifier for the microbatch.

        Returns:
            A tuple containing (inputs, outputs).

        Raises:
            KeyError: If no forward pass is cached for this microbatch.
        """

        cache = self._cache.pop(microbatch_index)
        return cache.inputs, cache.outputs
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@dataclasses.dataclass(slots=True, kw_only=True)
class BackwardCacheInputForWeight:
    """
    Pending state of a split backward pass between its two phases.

    Created once gradients w.r.t. the stage inputs exist; the weight gradients are still
    outstanding and are computed later from this state.
    """

    # Gradients w.r.t. the stage inputs, keyed by input name.
    inputs_grad: dict[str, torch.Tensor]
    # Parameter groups whose weight gradients have not been computed yet.
    param_groups: list[ParamGroup]
    # References held so the parts of the autograd graph needed by the weight phase stay alive.
    ownership_tokens: list[Node]
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@dataclasses.dataclass(kw_only=True, slots=True)
|
|
137
|
+
class BackwardCacheInputForFull:
|
|
138
|
+
stage_outputs_or_loss: list[torch.Tensor]
|
|
139
|
+
output_grads: list[torch.Tensor] | None
|
|
140
|
+
input_values: list[torch.Tensor]
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@dataclasses.dataclass(kw_only=True, slots=True)
|
|
144
|
+
class BackwardCacheFull:
|
|
145
|
+
"""
|
|
146
|
+
State preserved after calculating weight gradients.
|
|
147
|
+
"""
|
|
148
|
+
|
|
149
|
+
inputs_grad: dict[str, torch.Tensor | None]
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class BackwardComputeHandler:
|
|
153
|
+
"""
|
|
154
|
+
Handles the execution of backward passes for a pipeline stage.
|
|
155
|
+
|
|
156
|
+
Supports splitting the backward pass into input-gradients and weight-gradients
|
|
157
|
+
phases, which is necessary for schedules like ZB.
|
|
158
|
+
"""
|
|
159
|
+
|
|
160
|
+
def __init__(
|
|
161
|
+
self,
|
|
162
|
+
stage_index: int,
|
|
163
|
+
module: nn.Module
|
|
164
|
+
):
|
|
165
|
+
"""
|
|
166
|
+
Constructs a BackwardComputeHandler object.
|
|
167
|
+
|
|
168
|
+
Args:
|
|
169
|
+
stage_index: Logical index of the stage.
|
|
170
|
+
module: The PyTorch module to compute gradients for.
|
|
171
|
+
"""
|
|
172
|
+
|
|
173
|
+
self._stage_idx = stage_index
|
|
174
|
+
self._module = module
|
|
175
|
+
|
|
176
|
+
self._cache: dict[int, BackwardCacheInputForWeight | BackwardCacheInputForFull | BackwardCacheFull] = {}
|
|
177
|
+
|
|
178
|
+
def _parameters_with_grad(self) -> Iterator[nn.Parameter]:
|
|
179
|
+
return (param for param in self._module.parameters() if param.requires_grad)
|
|
180
|
+
|
|
181
|
+
def backward_full(
|
|
182
|
+
self,
|
|
183
|
+
microbatch_index: int,
|
|
184
|
+
inputs: dict[str, torch.Tensor],
|
|
185
|
+
outputs: dict[str, torch.Tensor],
|
|
186
|
+
outputs_grad: dict[str, torch.Tensor] | None,
|
|
187
|
+
):
|
|
188
|
+
"""
|
|
189
|
+
Performs a full backward pass (both inputs and weights).
|
|
190
|
+
|
|
191
|
+
Args:
|
|
192
|
+
microbatch_index: Identifier for the microbatch.
|
|
193
|
+
inputs: The inputs used in the forward pass.
|
|
194
|
+
outputs: The outputs produced by the forward pass.
|
|
195
|
+
outputs_grad: Gradients of the loss with respect to the outputs.
|
|
196
|
+
"""
|
|
197
|
+
|
|
198
|
+
if microbatch_index in self._cache:
|
|
199
|
+
raise ValueError(f"S{self._stage_idx}B{microbatch_index} double backward")
|
|
200
|
+
|
|
201
|
+
inputs_flattener = DictFlattener(inputs.keys())
|
|
202
|
+
outputs_flattener = DictFlattener(outputs.keys())
|
|
203
|
+
|
|
204
|
+
inputs_grad_linear = stage_backward_full(
|
|
205
|
+
outputs=outputs_flattener.flatten(outputs),
|
|
206
|
+
output_grads=outputs_flattener.flatten(outputs_grad) if outputs_grad is not None else None,
|
|
207
|
+
inputs=inputs_flattener.flatten(inputs)
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
if self._stage_idx != 0:
|
|
211
|
+
self._cache[microbatch_index] = BackwardCacheFull(
|
|
212
|
+
inputs_grad=inputs_flattener.unflatten(inputs_grad_linear)
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
def backward_input(
|
|
216
|
+
self,
|
|
217
|
+
microbatch_index: int,
|
|
218
|
+
inputs: dict[str, torch.Tensor],
|
|
219
|
+
outputs: dict[str, torch.Tensor],
|
|
220
|
+
outputs_grad: dict[str, torch.Tensor] | None
|
|
221
|
+
):
|
|
222
|
+
"""
|
|
223
|
+
Performs a partial backward pass to compute gradients with respect to inputs only.
|
|
224
|
+
|
|
225
|
+
This prepares the computation state for a subsequent `backward_weight` call.
|
|
226
|
+
|
|
227
|
+
Args:
|
|
228
|
+
microbatch_index: Identifier for the microbatch.
|
|
229
|
+
inputs: The inputs used in the forward pass.
|
|
230
|
+
outputs: The outputs produced by the forward pass.
|
|
231
|
+
outputs_grad: Gradients of the loss with respect to the outputs.
|
|
232
|
+
"""
|
|
233
|
+
|
|
234
|
+
if microbatch_index in self._cache:
|
|
235
|
+
raise ValueError("Double backward pass")
|
|
236
|
+
|
|
237
|
+
inputs_flattener = DictFlattener(inputs.keys())
|
|
238
|
+
outputs_flattener = DictFlattener(outputs.keys())
|
|
239
|
+
|
|
240
|
+
if self._stage_idx == 0:
|
|
241
|
+
self._cache[microbatch_index] = BackwardCacheInputForFull(
|
|
242
|
+
stage_outputs_or_loss=outputs_flattener.flatten(outputs),
|
|
243
|
+
output_grads=outputs_flattener.flatten(outputs_grad) if outputs_grad is not None else None,
|
|
244
|
+
input_values=inputs_flattener.flatten(inputs)
|
|
245
|
+
)
|
|
246
|
+
else:
|
|
247
|
+
results = stage_backward_input(
|
|
248
|
+
outputs=outputs_flattener.flatten(outputs),
|
|
249
|
+
output_grads=outputs_flattener.flatten(outputs_grad) if outputs_grad is not None else None,
|
|
250
|
+
inputs=inputs_flattener.flatten(inputs),
|
|
251
|
+
weights=self._parameters_with_grad()
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
self._cache[microbatch_index] = BackwardCacheInputForWeight(
|
|
255
|
+
inputs_grad=inputs_flattener.unflatten(cast(Sequence[torch.Tensor], results.input_grads)),
|
|
256
|
+
param_groups=results.param_groups,
|
|
257
|
+
ownership_tokens=results.grad_ownership_tokens
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
def backward_weight(
|
|
261
|
+
self,
|
|
262
|
+
microbatch_index: int
|
|
263
|
+
):
|
|
264
|
+
"""
|
|
265
|
+
Performs a partial backward pass to accumulate gradients into weights.
|
|
266
|
+
|
|
267
|
+
Must be preceded by `backward_input` for the same microbatch index.
|
|
268
|
+
|
|
269
|
+
Args:
|
|
270
|
+
microbatch_index: Identifier for the microbatch.
|
|
271
|
+
"""
|
|
272
|
+
|
|
273
|
+
if microbatch_index not in self._cache:
|
|
274
|
+
raise ValueError(f"S{self._stage_idx}BW{microbatch_index} - weight backward with no input backward before")
|
|
275
|
+
|
|
276
|
+
prev_cache = self._cache.pop(microbatch_index)
|
|
277
|
+
|
|
278
|
+
match prev_cache:
|
|
279
|
+
case BackwardCacheInputForFull():
|
|
280
|
+
stage_backward_full(
|
|
281
|
+
outputs=prev_cache.stage_outputs_or_loss,
|
|
282
|
+
output_grads=prev_cache.output_grads,
|
|
283
|
+
inputs=prev_cache.input_values
|
|
284
|
+
)
|
|
285
|
+
case BackwardCacheInputForWeight():
|
|
286
|
+
stage_backward_weight(
|
|
287
|
+
weights=self._parameters_with_grad(),
|
|
288
|
+
param_groups=prev_cache.param_groups
|
|
289
|
+
)
|
|
290
|
+
case _:
|
|
291
|
+
raise ValueError("Previous backward was not input backward")
|
|
292
|
+
|
|
293
|
+
def pop_for_sending(self, microbatch_index: int) -> dict[str, torch.Tensor]:
|
|
294
|
+
"""
|
|
295
|
+
Retrieves the calculated input gradients for a microbatch.
|
|
296
|
+
|
|
297
|
+
Args:
|
|
298
|
+
microbatch_index: Identifier for the microbatch.
|
|
299
|
+
|
|
300
|
+
Returns:
|
|
301
|
+
Dictionary of gradient tensors.
|
|
302
|
+
"""
|
|
303
|
+
cached = self._cache[microbatch_index]
|
|
304
|
+
|
|
305
|
+
match cached:
|
|
306
|
+
case BackwardCacheFull():
|
|
307
|
+
del self._cache[microbatch_index]
|
|
308
|
+
case BackwardCacheInputForWeight():
|
|
309
|
+
pass
|
|
310
|
+
case _:
|
|
311
|
+
raise ValueError("You should call either backward_full or backward_input before popping cached grad")
|
|
312
|
+
|
|
313
|
+
for grad_value in cached.inputs_grad.values():
|
|
314
|
+
if grad_value is None:
|
|
315
|
+
raise ValueError("Cannot pop null gradient for sending! Perhaps malformed schedule?")
|
|
316
|
+
|
|
317
|
+
return cast(dict[str, torch.Tensor], cached.inputs_grad)
|
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
from collections import defaultdict, deque
|
|
2
|
+
from collections.abc import Callable, Iterator
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, cast
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import nn
|
|
8
|
+
from torch.autograd.graph import GradientEdge, Node
|
|
9
|
+
|
|
10
|
+
from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def stage_backward_full(
    outputs: list[torch.Tensor],
    output_grads: list[torch.Tensor] | None,
    inputs: list[torch.Tensor]
) -> list[torch.Tensor | None]:
    """
    Runs an ordinary, non-split backward pass for a pipeline stage.

    Gradients flow from `outputs` all the way back, so both input and weight gradients
    are produced in a single autograd call.

    Args:
        outputs: The output tensors of the forward pass.
        output_grads: The gradients arriving from the next pipeline stage corresponding
            to `outputs`. If None, assumes scalar output or implied ones.
        inputs: The input tensors to the forward pass for which gradients are required.

    Returns:
        One entry per tensor in `inputs`, in order: its `.grad` after the backward, or None when
        it received no gradient. Each input's `.grad` is reset to None after being read.
    """

    # Both gradient directions are enabled in d9d's global grad context for a full pass.
    with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.inputs, GradDirection.weight):
        torch.autograd.backward(
            tensors=outputs,
            grad_tensors=output_grads
        )

    def _take_grad(tensor: torch.Tensor) -> torch.Tensor | None:
        # Read-and-clear per tensor, so later calls do not accumulate into a stale `.grad`.
        grad = tensor.grad
        tensor.grad = None
        return grad

    return [_take_grad(tensor) for tensor in inputs]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
|
|
49
|
+
class ParamGroup:
|
|
50
|
+
"""
|
|
51
|
+
Represents a group of parameters and their dependency intermediates in the autograd graph.
|
|
52
|
+
|
|
53
|
+
This structure is used to manage the split backward pass, identifying which
|
|
54
|
+
intermediate nodes in the graph allow gradients to flow to specific sets of parameters.
|
|
55
|
+
|
|
56
|
+
Attributes:
|
|
57
|
+
params: Set of autograd Nodes representing the parameters.
|
|
58
|
+
intermediates: List of autograd Nodes serving as entry points for gradients
|
|
59
|
+
flowing to these parameters.
|
|
60
|
+
grads: Storage for captured gradients at the intermediate nodes during
|
|
61
|
+
the input backward phase.
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
params: set[Node]
|
|
65
|
+
intermediates: list[Node] | None
|
|
66
|
+
grads: list[torch.Tensor | None] | None = None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None:
|
|
70
|
+
if t.requires_grad and t.grad_fn is None:
|
|
71
|
+
# hack from pytorch codebase to create accumulation op
|
|
72
|
+
viewed_t = t.view_as(t)
|
|
73
|
+
grad_fn = viewed_t.grad_fn
|
|
74
|
+
grad_fn = cast(Node, grad_fn)
|
|
75
|
+
return grad_fn.next_functions[0][0]
|
|
76
|
+
else:
|
|
77
|
+
return t.grad_fn
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _construct_reverse_graph(roots: list[Node]) -> dict[Node, list[Node]]:
|
|
81
|
+
"""
|
|
82
|
+
Builds a reverse adjacency list (Input -> Output) via BFS from the roots.
|
|
83
|
+
|
|
84
|
+
Standard autograd graphs point from Output -> Input (next_functions).
|
|
85
|
+
This helper provides the reverse mapping to assist in dependency analysis.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
roots: The starting nodes for the graph traversal.
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
A dictionary mapping a node to a list of its dependent (child) nodes.
|
|
92
|
+
"""
|
|
93
|
+
reverse_graph = defaultdict(list)
|
|
94
|
+
valid_roots = {x for x in roots if x is not None}
|
|
95
|
+
to_visit = deque(valid_roots)
|
|
96
|
+
visited = set(valid_roots)
|
|
97
|
+
|
|
98
|
+
while to_visit:
|
|
99
|
+
current_node = to_visit.popleft()
|
|
100
|
+
for parent_node, _ in current_node.next_functions:
|
|
101
|
+
if parent_node is None:
|
|
102
|
+
continue
|
|
103
|
+
reverse_graph[parent_node].append(current_node)
|
|
104
|
+
if parent_node not in visited:
|
|
105
|
+
visited.add(parent_node)
|
|
106
|
+
to_visit.append(parent_node)
|
|
107
|
+
|
|
108
|
+
return reverse_graph
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _reverse_closure(
    roots: list[Node], target_nodes: set[Node], reverse_edges_dict: dict[Node, list[Node]]
) -> tuple[set[Node], set[Node]]:
    """
    Collects everything reachable from `roots` along the reversed autograd edges.

    The walk does not pass through nodes in `target_nodes`: such nodes are recorded as hit
    but neither expanded nor put in the closure. Roots are always part of the closure.

    Args:
        roots: Starting nodes; None entries are ignored.
        target_nodes: Boundary nodes at which the walk stops.
        reverse_edges_dict: Node -> dependents mapping (indexed directly, so it must supply
            entries for every visited node, e.g. a defaultdict).

    Returns:
        A pair ``(closure, reached_targets)``.
    """

    closure: set[Node] = set()
    reached_targets: set[Node] = set()
    queue: deque[Node] = deque()

    for root in roots:
        if root is None or root in closure:
            continue
        closure.add(root)
        queue.append(root)

    while queue:
        current = queue.popleft()
        for dependent in reverse_edges_dict[current]:
            if dependent is None or dependent in closure:
                continue
            if dependent in target_nodes:
                reached_targets.add(dependent)
            else:
                closure.add(dependent)
                queue.append(dependent)

    return closure, reached_targets
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _get_param_groups(
    inputs: list[Node], params: list[Node], reverse_edges_dict: dict[Node, list[Node]]
) -> list[ParamGroup]:
    """
    Clusters parameters based on their dependencies on inputs.

    For every parameter, its "intermediates" are the first nodes on its paths towards the
    outputs that also depend on some stage input. Parameters whose intermediate sets overlap,
    directly or transitively, are merged into a single group, so every intermediate belongs
    to exactly one group.

    Args:
        inputs: Gradient functions of the input tensors.
        params: Gradient functions of the parameter tensors.
        reverse_edges_dict: The reverse autograd graph.

    Returns:
        A list of distinct parameter groups with disjoint intermediates. Parameters with no
        intermediate at all are not part of any group.
    """

    inputs_closure, _ = _reverse_closure(inputs, set(), reverse_edges_dict)

    # Each intermediate node -> the group dict that currently owns it.
    node_to_group_map: dict[Node, dict[str, set[Node]]] = {}

    for param in params:
        _, intersected_inputs = _reverse_closure(
            [param], inputs_closure, reverse_edges_dict
        )

        merged_group: dict[str, set[Node]] = {
            "params": {param},
            "intermediates": set(intersected_inputs),
        }

        # Absorb EVERY existing group that shares an intermediate with this parameter. Merging
        # only the first match would remap the other groups' shared nodes away from them and
        # make those groups (and their parameters) vanish from the result.
        for intermediate_node in intersected_inputs:
            existing_group = node_to_group_map.get(intermediate_node)
            if existing_group is not None:
                merged_group["params"].update(existing_group["params"])
                merged_group["intermediates"].update(existing_group["intermediates"])

        # The merged group now owns all nodes of the groups it absorbed.
        for intermediate_node in merged_group["intermediates"]:
            node_to_group_map[intermediate_node] = merged_group

    # Deduplicate (many nodes share one group dict) and convert to dataclasses.
    unique_groups: list[ParamGroup] = []
    seen_ids: set[int] = set()
    for group_dict in node_to_group_map.values():
        if id(group_dict) in seen_ids:
            continue
        seen_ids.add(id(group_dict))
        unique_groups.append(ParamGroup(
            params=group_dict["params"],
            intermediates=list(group_dict["intermediates"])
        ))

    return unique_groups
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _make_capture_hook(group: ParamGroup, idx: int) -> Callable[[torch.Tensor], None]:
    """
    Builds a hook that stores whatever it receives into ``group.grads[idx]``.

    The ``grads`` list is allocated lazily (one slot per intermediate) on the first call.
    If the group has no intermediates, nothing is stored. The hook returns None, so it
    does not replace the gradients it observes.

    Args:
        group: Group whose ``grads`` list receives the captured value.
        idx: Slot in ``group.grads`` to write.

    Returns:
        The hook callable.
    """

    def _capture(grad_outputs: torch.Tensor):
        # NOTE(review): this is registered via Node.register_prehook, whose hooks receive a tuple
        # of gradients, so the stored value is likely a tuple despite the annotation — confirm.
        if group.grads is None:
            if group.intermediates is None:
                return
            group.grads = [None] * len(group.intermediates)
        group.grads[idx] = grad_outputs

    return _capture
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@dataclass
class BackwardInputResult:
    """
    Output of the input-gradient phase of a split backward pass.

    Attributes:
        input_grads: Gradients for the input tensors, in input order (None where absent).
        param_groups: Parameter groups whose capture hooks will supply the data for the
            later weight-gradient phase.
        grad_ownership_tokens: References retained so the computation graph stays alive
            until the weight backward phase.
    """

    input_grads: list[torch.Tensor | None]
    param_groups: list[ParamGroup]
    grad_ownership_tokens: list[Any]
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def stage_backward_input(
    outputs: list[torch.Tensor],
    output_grads: list[torch.Tensor] | None,
    inputs: list[torch.Tensor],
    weights: Iterator[nn.Parameter],
) -> BackwardInputResult:
    """
    Performs the first phase of a split backward pass: Input Gradients.

    This function computes the gradients with respect to `inputs` while postponing
    the computation of gradients with respect to `weights`. It analyzes the
    autograd graph to identify intermediate nodes where gradients destined for
    weights split off from the main flow. Hooks are registered at these
    intermediates to capture gradients for the second phase (`stage_backward_weight`).

    The capture hooks are always removed before returning, including when the
    backward pass raises, so a failed step does not leave stale hooks attached to
    the autograd graph. On success, the `.grad` of every tensor in `inputs` is
    reset to None after being read.

    Args:
        outputs: The output tensors of the forward pass.
        output_grads: The gradients arriving for the outputs. If None, a tensor of
            ones shaped like each output is used (e.g. when `outputs` is a loss).
        inputs: The input tensors from the forward pass. Only those with
            `requires_grad=True` are differentiated.
        weights: An iterator over the model parameters (weights). It is consumed.

    Returns:
        A result object containing input gradients (one entry per element of
        `inputs`, None where no gradient was produced), prepared parameter groups,
        and ownership tokens to maintain graph validity.
    """

    outputs_grad_fn = [grad_fn for x in outputs if (grad_fn := _get_grad_fn_or_grad_acc(x)) is not None]
    inputs_grad_fn = [grad_fn for x in inputs if (grad_fn := _get_grad_fn_or_grad_acc(x)) is not None]
    weights_grad_fn = [grad_fn for x in weights if (grad_fn := _get_grad_fn_or_grad_acc(x)) is not None]

    reverse_edges = _construct_reverse_graph(outputs_grad_fn)
    param_groups = _get_param_groups(inputs_grad_fn, weights_grad_fn, reverse_edges)

    hook_handles = []
    try:
        # Capture the gradients reaching each split-point intermediate so the
        # weight phase can restart backward from there.
        for group in param_groups:
            if not group.intermediates:
                continue
            for i, node in enumerate(group.intermediates):
                hook_handles.append(node.register_prehook(_make_capture_hook(group, i)))

        if output_grads is None:
            output_grads = [torch.ones_like(o) for o in outputs]

        # NOTE(review): torch.autograd.backward rejects an empty `inputs` list, so this
        # presumably is never called when no input requires grad -- confirm with callers.
        inputs_requiring_grad = [inp for inp in inputs if inp.requires_grad]

        with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.inputs):
            torch.autograd.backward(
                tensors=outputs,
                grad_tensors=output_grads,
                inputs=inputs_requiring_grad,
                # The weight phase re-enters this graph from the captured intermediates.
                retain_graph=True,
            )

        # Hand the input gradients back to the caller and detach them from the tensors.
        final_input_grads = []
        for input_item in inputs:
            final_input_grads.append(input_item.grad)
            input_item.grad = None
    finally:
        # Hooks are only needed during the input backward; never leave them attached.
        for handle in hook_handles:
            handle.remove()

    return BackwardInputResult(
        input_grads=final_input_grads,
        param_groups=param_groups,
        # TODO(max): we can keep only intermediate ownership tokens to both truncate the
        # TODO(max): graph and do not deallocate C++ stuff
        grad_ownership_tokens=outputs,  # Keep the tensors alive!
    )
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def stage_backward_weight(  # noqa: C901
    weights: Iterator[nn.Parameter],
    param_groups: list[ParamGroup],
    retain_graph: bool = False
) -> tuple[torch.Tensor | None, ...]:
    """
    Performs the second phase of a split backward pass: Weight Gradients.

    This function consumes the gradients captured in the `ParamGroup`s during
    `stage_backward_input` to compute the final gradients for the model weights.
    It triggers backward passes starting from the intermediate nodes identified previously.
    Each group's `intermediates` and `grads` are set to None as it is processed,
    dropping the references to graph nodes and captured gradients.

    Args:
        weights: An iterator over the model parameters to extract gradients for.
        param_groups: The list of groups containing captured intermediate gradients.
        retain_graph: Whether to retain the graph after this backward pass.

    Returns:
        A tuple of gradients corresponding to the provided `weights`, i.e. the
        current `.grad` of each weight (backward accumulates into it).

    Raises:
        ValueError: If a group has captured gradients but one of its intermediates
            received none.
    """

    all_weights = list(weights)  # Materialize once, keeping the caller's order.
    grad_acc_to_weight = {
        grad_acc: weight
        for weight in all_weights
        if (grad_acc := _get_grad_fn_or_grad_acc(weight)) is not None
    }

    for group in param_groups:
        valid_edges = []
        valid_grad_outputs: list[torch.Tensor] = []

        # Only groups whose capture hooks fired have data to propagate.
        if group.grads and group.intermediates:
            for grads_tuple, intermediate in zip(group.grads, group.intermediates, strict=True):
                if grads_tuple is None:
                    raise ValueError("Trying to do backward_weight with no intermediate grads")
                non_none = [g for g in grads_tuple if g is not None]
                if non_none:
                    # All gradients reaching the node are summed and fed into its edge 0.
                    valid_edges.append(GradientEdge(intermediate, 0))
                    valid_grad_outputs.append(cast(torch.Tensor, sum(non_none)))

        # Break Cycle: Intermediates
        group.intermediates = None

        if valid_edges:
            inputs_for_backward = [grad_acc_to_weight[node] for node in group.params if node in grad_acc_to_weight]

            if inputs_for_backward:
                with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.weight):
                    torch.autograd.backward(
                        tensors=valid_edges,
                        grad_tensors=valid_grad_outputs,
                        retain_graph=retain_graph,
                        inputs=inputs_for_backward
                    )

        # Break Cycle: Grads
        group.grads = None

    return tuple(w.grad for w in all_weights)
|