d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,169 @@
import math

import torch.distributed as dist
import torch.nn.utils
from torch import nn
from torch.autograd.profiler import record_function
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor

from d9d.internals.grad_norm.group import ParametersForNorm


def _reduce_op_from_norm_type(norm_type: float) -> dist.ReduceOp.RedOpType:
    if math.isinf(norm_type):
        return dist.ReduceOp.MAX
    else:
        return dist.ReduceOp.SUM


def _parameter_to_local_grad(parameter: nn.Parameter) -> torch.Tensor:
    grad = parameter.grad

    if grad is None:
        raise ValueError("None grad detected")

    if isinstance(grad, DTensor):
        return grad.to_local()
    else:
        return grad


def _get_local_norm_pow(
    parameters: list[nn.Parameter],
    norm_type: float
) -> torch.Tensor:
    # calculates for local

    if len(parameters) == 0:
        return torch.tensor(0.0, device="cuda")

    norm_val = torch.nn.utils.get_total_norm(
        [_parameter_to_local_grad(x) for x in parameters],
        norm_type=norm_type,
        foreach=True,
        error_if_nonfinite=False
    )

    if math.isinf(norm_type):
        return norm_val
    else:
        return norm_val ** norm_type


def _get_global_norm_pow_horizontal(
    parameter_groups: ParametersForNorm,
    norm_type: float
) -> torch.Tensor:
    # calculates for horizontal parallelism
    if len(parameter_groups) == 0:
        return torch.tensor(0.0, device="cuda")

    norms: list[torch.Tensor] = []
    works: list[dist.Work] = []
    for group, group_params in parameter_groups.items():
        local_norm_pow = _get_local_norm_pow(group_params, norm_type=norm_type)
        if group.shard_meshes is not None:
            if len(group.shard_meshes) != 1:
                raise ValueError(
                    "Currently we do not support calculating norm for tensors that are sharded on multiple dims - feel "
                    "free to file an issue if you need it."
                )
            process_group = group.shard_meshes[0].get_group()
            work = dist.all_reduce(
                local_norm_pow,
                op=_reduce_op_from_norm_type(norm_type),
                group=process_group,
                async_op=True
            )
            works.append(work)
        norms.append(local_norm_pow)

    for work in works:
        work.wait()

    norms_total = torch.stack(norms, dim=0)

    if math.isinf(norm_type):
        return norms_total.max()
    else:
        return norms_total.sum()


def _get_global_norm_pow_pp(
    parameter_groups: ParametersForNorm,
    norm_type: float,
    pp_mesh: DeviceMesh | None
) -> torch.Tensor:
    norm = _get_global_norm_pow_horizontal(
        parameter_groups=parameter_groups,
        norm_type=norm_type
    )
    if pp_mesh is not None:
        dist.all_reduce(norm, op=_reduce_op_from_norm_type(norm_type), group=pp_mesh.get_group())
    return norm


def _clip_grad_with_norm_(
    parameter_groups: ParametersForNorm,
    max_norm: float,
    total_norm: torch.Tensor
):
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

    for group in parameter_groups.values():
        grads = [_parameter_to_local_grad(x) for x in group]
        torch._foreach_mul_(grads, clip_coef_clamped)


def clip_grad_norm_distributed_(
    parameter_groups: ParametersForNorm,
    max_norm: float | None,
    norm_type: float,
    pp_mesh: DeviceMesh | None
) -> torch.Tensor:
    """
    Clips gradient norms in a fully distributed environment.

    This function calculates the global gradient norm across all dimensions of parallelism
    (Horizontal - DP/CP/TP/EP/..., and Pipeline) and scales the gradients in-place to ensure the norm
    does not exceed max_norm.

    It accurately handles DTensors by identifying their sharding placements and performing
    reductions only on the necessary process groups.

    Overlaps communication and computation if possible.

    Args:
        parameter_groups: Dictionary grouping parameters by synchronization requirements,
            typically created by `group_parameters_for_norm`.
        max_norm: The maximum allowed norm of the gradients. If None, the function
            calculates and returns the global norm without modifying the gradients.
        norm_type: The type of the norm to calculate (e.g., 2.0 for L2 norm, inf for max norm).
        pp_mesh: The device mesh representing the pipeline parallel dimension, needed
            to reduce norms across pipeline stages.

    Returns:
        The calculated global gradient norm.
    """

    with record_function("Gradient Clipping"):
        global_norm_pow = _get_global_norm_pow_pp(
            parameter_groups=parameter_groups,
            norm_type=norm_type,
            pp_mesh=pp_mesh
        )
        if math.isinf(norm_type):
            global_norm = global_norm_pow
        else:
            global_norm = global_norm_pow ** (1.0 / norm_type)

        if max_norm:
            _clip_grad_with_norm_(
                parameter_groups,
                max_norm=max_norm,
                total_norm=global_norm
            )

        return global_norm
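The hunk above (which matches the `d9d/internals/grad_norm/norm.py +169` entry in the file listing) distributes the norm computation by reducing each group's norm raised to the p-th power across its shard groups and the pipeline mesh, and only then takes the p-th root. Below is a minimal stand-alone sketch of that same trick written against plain `torch.distributed`, for a parameter set whose gradients are sharded across the ranks of the default process group (each rank holds a disjoint local shard). The function name and the single default process group are assumptions for illustration; this is not d9d's API.

# Minimal sketch, not d9d's API. Assumes dist.init_process_group() has been
# called and every rank holds only its local shard of each gradient on CUDA.
import math

import torch
import torch.distributed as dist
from torch import nn


def clip_grad_norm_sharded_(params: list[nn.Parameter], max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    grads = [p.grad for p in params if p.grad is not None]
    # local norm raised to the p-th power, so SUM across disjoint shards is exact
    local = torch.nn.utils.get_total_norm(grads, norm_type=norm_type, foreach=True)
    local_pow = local if math.isinf(norm_type) else local ** norm_type
    # MAX for the inf-norm, SUM otherwise - the same rule as _reduce_op_from_norm_type
    op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM
    dist.all_reduce(local_pow, op=op)
    total = local_pow if math.isinf(norm_type) else local_pow ** (1.0 / norm_type)
    # scale the local gradient shards in place so the global norm stays below max_norm
    clip_coef = torch.clamp(max_norm / (total + 1e-6), max=1.0)
    torch._foreach_mul_(grads, clip_coef)
    return total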
@@ -0,0 +1,14 @@
"""
Gradient synchronization utilities.

This package provides the infrastructure for manual gradient bucketing and
asynchronous reduction, similar to DistributedDataParallel but exposed
for internal framework usage with DTensors.
"""


from .synchronizer import GradientSynchronizer

__all__ = [
    "GradientSynchronizer"
]
@@ -0,0 +1,317 @@
import abc
from typing import cast

import torch
import torch.distributed as dist
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor
from torch.utils.hooks import RemovableHandle

from .placement_helper import dist_grad_from_local


class AbstractGradientBucket(abc.ABC):
    """
    Interface for a bucket containing a subset of model parameters.

    A bucket manages the memory layout and synchronization lifecycle of the
    gradients associated with its parameters.
    """

    @abc.abstractmethod
    def bind(self):
        """
        Initializes the bucket state.

        This involves allocating contiguous memory buffers (if applicable),
        registering backward hooks, and preparing the gradients for accumulation.
        """

    @abc.abstractmethod
    def unbind(self):
        """
        Cleans up the bucket state.

        Removes hooks, deallocates buffers, and detaches gradients.
        """

    @abc.abstractmethod
    def zero_grad(self):
        """
        Zeros out the gradients and resets accumulation counters.
        """

    @abc.abstractmethod
    def mark_sync(self):
        """
        Marks this bucket as synchronized.
        """


class LocalGradientBucket(AbstractGradientBucket):
    """
    A bucket for parameters that do not require distributed synchronization.
    """

    def __init__(self, params: list[nn.Parameter]):
        """
        Constructs a LocalGradientBucket.

        Args:
            params: List of parameters to manage.
        """

        self._params = params

    def bind(self):
        """
        No-op for local buckets as they do not require special buffering.
        """

    def unbind(self):
        """
        No-op for local buckets.
        """

    def wait(self):
        """
        No-op as no async communication is performed.
        """

    @torch.no_grad()
    def zero_grad(self):
        """
        Directly zeros the grad attribute of the parameters.
        """

        for param in self._params:
            param.grad = None

    def mark_sync(self):
        """
        No-op for local buckets.
        """


class AccumulationCounter:
    """
    Tracks the number of gradient accumulation steps for a set of parameters.
    """

    def __init__(self, require_accumulations: int, parameters: list[nn.Parameter]):
        """
        Constructs an AccumulationCounter.

        Args:
            require_accumulations: Number of accumulations required before sync.
            parameters: List of parameters to track.
        """

        self._require_accumulations = require_accumulations
        self._param_to_sync_count = {param: 0 for param in parameters}

    def reset(self):
        """
        Resets all counters to zero.
        """

        self._param_to_sync_count = {param: 0 for param in self._param_to_sync_count}

    def update(self, param: nn.Parameter):
        """
        Increments the counter for a specific parameter.

        Args:
            param: The parameter that finished a backward step.
        """

        self._param_to_sync_count[param] += 1

    def is_ready(self) -> bool:
        """
        Checks if all parameters have reached the required number of accumulations.

        Returns:
            True if synchronization can proceed.
        """

        return all(x == self._require_accumulations for x in self._param_to_sync_count.values())


class SyncGradientBucket(AbstractGradientBucket):
    """
    A bucket that manages a contiguous memory buffer for gradients and performs async reduction.

    This bucket flattens the gradients of its parameters into a single contiguous
    Tensor to enable efficient batched all-reduce operations.
    """

    def __init__(
        self,
        parameters: list[nn.Parameter],
        require_accumulations: int,
        device: torch.device,
        grad_dtype: torch.dtype,
        reduce_mesh: DeviceMesh,
        communicate_stream: torch.cuda.Stream
    ):
        """
        Constructs a SyncGradientBucket.

        Args:
            parameters: List of parameters to manage.
            require_accumulations: Number of accumulations before triggering reduce.
            device: Device where parameters reside.
            grad_dtype: Data type for the gradients.
            reduce_mesh: DeviceMesh on which reduction happens.
            communicate_stream: Stream where all the asynchronous communications will be scheduled
        """

        if not all(isinstance(x.data, DTensor) for x in parameters):
            raise ValueError("All parameters passed in synchronizable bucket should contain DTensor data")

        self._params = parameters
        self._accum_counter = AccumulationCounter(require_accumulations, parameters)
        self._device = device
        self._grad_dtype = grad_dtype
        # iterate from innermost to outermost group
        self._reduce_groups: list[dist.ProcessGroup] = reduce_mesh.get_all_groups()[::-1]

        self._buffer: Tensor | None = None
        self._hooks: list[RemovableHandle] | None = None

        self._communicate_stream = communicate_stream
        self._ready_to_sync = False

    def _bind_buffer(self):
        """
        Allocates the flat buffer and redirects parameter gradients to view into it.
        """

        buffer_size = sum(cast(DTensor, param.data).to_local().numel() for param in self._params)

        self._buffer = torch.zeros(
            (buffer_size,),
            dtype=self._grad_dtype,
            device=self._device
        )

        offset = 0

        for param in self._params:
            data = cast(DTensor, param.data)
            local_param = data.to_local()

            local_grad = self._buffer[offset:offset + local_param.numel()].view(local_param.shape)

            param.grad = dist_grad_from_local(data, local_grad)

            offset += local_param.numel()

    @torch.no_grad()
    def _post_accumulation_hook(self, param: nn.Parameter):
        """
        Hook executed after backward pass for a parameter.

        Updates the accumulation counter and triggers the asynchronous all-reduce
        if the bucket is ready.

        Args:
            param: The parameter that finished backward pass.
        """

        self._accum_counter.update(param)

        if not self._accum_counter.is_ready():
            return

        if self._ready_to_sync:
            raise ValueError("Tried to accumulate, but synchronization was not performed")

        with record_function("Gradient Sync"):
            # wait for backward operation is complete
            self._communicate_stream.wait_stream(torch.cuda.current_stream())
            # execute all sync operations in sequential order (to ensure
            # data safety), but in a DIFFERENT stream
            with torch.cuda.stream(self._communicate_stream):
                for group in self._reduce_groups:
                    dist.all_reduce(
                        self._buffer,
                        op=dist.ReduceOp.SUM,
                        group=group
                    )
            self._ready_to_sync = True

    def _bind_hooks(self):
        """
        Registers post-accumulate hooks on all parameters.
        """

        hooks = []
        for param in self._params:
            hooks.append(param.register_post_accumulate_grad_hook(self._post_accumulation_hook))
        self._hooks = hooks

    @torch.no_grad()
    def bind(self):
        """
        Allocates the contiguous buffer and registers hooks.
        """

        self._bind_buffer()
        self._bind_hooks()

    def _unbind_buffer(self):
        """
        Deallocates the buffer and clears parameter gradients.
        """

        self._buffer = None

        for param in self._params:
            param.grad = None

    def _unbind_hooks(self):
        """
        Removes all registered hooks.
        """

        if self._hooks is None:
            return

        for hook in self._hooks:
            hook.remove()
        self._hooks = None

    @torch.no_grad()
    def unbind(self):
        """
        Cleans up buffer and hooks.
        """

        self._unbind_buffer()
        self._unbind_hooks()

    @torch.no_grad()
    def zero_grad(self):
        """
        Zeros the contiguous buffer, resets counters, and marks params as awaiting sync.

        Raises:
            ValueError: If the buffer is not initialized (call bind first).
        """

        buffer = self._buffer
        if buffer is None:
            raise ValueError("Buffer is not initialized")

        buffer.zero_()
        self._accum_counter.reset()

    def mark_sync(self):
        if not self._ready_to_sync:
            raise ValueError("This bucket is not ready for sync.")

        self._ready_to_sync = False
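The `SyncGradientBucket` above drives its reduction from `register_post_accumulate_grad_hook` and schedules the all-reduce on a dedicated CUDA stream so communication can overlap with the rest of the backward pass. Below is a minimal, self-contained sketch of that hook-plus-side-stream pattern using plain `torch.distributed`; the toy model, the single default process group, and the per-parameter all-reduce are assumptions for illustration, not the package's synchronizer API.

# Sketch only. Assumes dist.init_process_group("nccl") and one CUDA device per rank.
import torch
import torch.distributed as dist
from torch import nn

model = nn.Linear(32, 32).cuda()
comm_stream = torch.cuda.Stream()


def reduce_on_side_stream(param: nn.Parameter) -> None:
    # fires once param.grad has been fully accumulated for this backward pass
    comm_stream.wait_stream(torch.cuda.current_stream())  # wait for the producing kernels
    with torch.cuda.stream(comm_stream):
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)


handles = [p.register_post_accumulate_grad_hook(reduce_on_side_stream) for p in model.parameters()]

loss = model(torch.randn(8, 32, device="cuda")).sum()
loss.backward()

# the consumer (e.g. the optimizer step) must wait for the communication stream,
# mirroring what a mark_sync / wait step would do in the bucketed version
torch.cuda.current_stream().wait_stream(comm_stream)

for h in handles:  # analogous to unbind(): remove hooks when tearing the setup down
    h.remove()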
@@ -0,0 +1,23 @@
from torch import Tensor
from torch.distributed.tensor import DTensor


def dist_grad_from_local(data: DTensor, local_grad: Tensor) -> DTensor:
    """
    Constructs a DTensor gradient from a local tensor using data placement info.

    Args:
        data: The original parameter DTensor (source of metadata).
        local_grad: The local tensor containing gradient data.

    Returns:
        A new DTensor wrapping the local gradient.
    """

    return DTensor.from_local(
        local_grad,
        shape=data.shape,
        stride=data.stride(),
        device_mesh=data.device_mesh,
        placements=data.placements
    )
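`dist_grad_from_local` is what lets the bucket carve `.grad` views out of a single flat buffer while keeping them typed as DTensors with the parameter's own mesh and placements. A hypothetical usage sketch follows; the process-group setup, the 1-D device mesh, and the row-sharded parameter are assumptions for illustration and are not defined by this file.

# Sketch only. Assumes dist.init_process_group() has been called (e.g. under torchrun)
# and one CUDA device per rank.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

from d9d.internals.grad_sync.placement_helper import dist_grad_from_local

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
param = torch.nn.Parameter(distribute_tensor(torch.randn(1024, 64), mesh, [Shard(0)]))

local = param.data.to_local()
# one flat fp32 buffer; a real bucket packs many parameters into it back to back
buffer = torch.zeros(local.numel(), dtype=torch.float32, device=local.device)
local_grad_view = buffer[: local.numel()].view(local.shape)

# the resulting grad is a DTensor with the same mesh/placements as the parameter,
# but its local storage aliases the flat buffer
param.grad = dist_grad_from_local(param.data, local_grad_view)
assert param.grad.placements == param.data.placements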