d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/internals/grad_norm/norm.py
@@ -0,0 +1,169 @@
+import math
+
+import torch.distributed as dist
+import torch.nn.utils
+from torch import nn
+from torch.autograd.profiler import record_function
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor import DTensor
+
+from d9d.internals.grad_norm.group import ParametersForNorm
+
+
+def _reduce_op_from_norm_type(norm_type: float) -> dist.ReduceOp.RedOpType:
+    if math.isinf(norm_type):
+        return dist.ReduceOp.MAX
+    else:
+        return dist.ReduceOp.SUM
+
+
+def _parameter_to_local_grad(parameter: nn.Parameter) -> torch.Tensor:
+    grad = parameter.grad
+
+    if grad is None:
+        raise ValueError("None grad detected")
+
+    if isinstance(grad, DTensor):
+        return grad.to_local()
+    else:
+        return grad
+
+
+def _get_local_norm_pow(
+    parameters: list[nn.Parameter],
+    norm_type: float
+) -> torch.Tensor:
+    # calculates the norm of the local gradient shards, raised to the power norm_type
+
+    if len(parameters) == 0:
+        return torch.tensor(0.0, device="cuda")
+
+    norm_val = torch.nn.utils.get_total_norm(
+        [_parameter_to_local_grad(x) for x in parameters],
+        norm_type=norm_type,
+        foreach=True,
+        error_if_nonfinite=False
+    )
+
+    if math.isinf(norm_type):
+        return norm_val
+    else:
+        return norm_val ** norm_type
+
+
+def _get_global_norm_pow_horizontal(
+    parameter_groups: ParametersForNorm,
+    norm_type: float
+) -> torch.Tensor:
+    # reduces the powered norms across the horizontal (non-pipeline) parallel dimensions
+    if len(parameter_groups) == 0:
+        return torch.tensor(0.0, device="cuda")
+
+    norms: list[torch.Tensor] = []
+    works: list[dist.Work] = []
+    for group, group_params in parameter_groups.items():
+        local_norm_pow = _get_local_norm_pow(group_params, norm_type=norm_type)
+        if group.shard_meshes is not None:
+            if len(group.shard_meshes) != 1:
+                raise ValueError(
+                    "Currently we do not support calculating norm for tensors that are sharded on multiple dims - feel "
+                    "free to file an issue if you need it."
+                )
+            process_group = group.shard_meshes[0].get_group()
+            work = dist.all_reduce(
+                local_norm_pow,
+                op=_reduce_op_from_norm_type(norm_type),
+                group=process_group,
+                async_op=True
+            )
+            works.append(work)
+        norms.append(local_norm_pow)
+
+    for work in works:
+        work.wait()
+
+    norms_total = torch.stack(norms, dim=0)
+
+    if math.isinf(norm_type):
+        return norms_total.max()
+    else:
+        return norms_total.sum()
+
+
+def _get_global_norm_pow_pp(
+    parameter_groups: ParametersForNorm,
+    norm_type: float,
+    pp_mesh: DeviceMesh | None
+) -> torch.Tensor:
+    norm = _get_global_norm_pow_horizontal(
+        parameter_groups=parameter_groups,
+        norm_type=norm_type
+    )
+    if pp_mesh is not None:
+        dist.all_reduce(norm, op=_reduce_op_from_norm_type(norm_type), group=pp_mesh.get_group())
+    return norm
+
+
+def _clip_grad_with_norm_(
+    parameter_groups: ParametersForNorm,
+    max_norm: float,
+    total_norm: torch.Tensor
+):
+    clip_coef = max_norm / (total_norm + 1e-6)
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+
+    for group in parameter_groups.values():
+        grads = [_parameter_to_local_grad(x) for x in group]
+        torch._foreach_mul_(grads, clip_coef_clamped)
+
+
+def clip_grad_norm_distributed_(
+    parameter_groups: ParametersForNorm,
+    max_norm: float | None,
+    norm_type: float,
+    pp_mesh: DeviceMesh | None
+) -> torch.Tensor:
+    """
+    Clips gradient norms in a fully distributed environment.
+
+    This function calculates the global gradient norm across all dimensions of parallelism
+    (Horizontal - DP/CP/TP/EP/..., and Pipeline) and scales the gradients in-place to ensure the norm
+    does not exceed max_norm.
+
+    It accurately handles DTensors by identifying their sharding placements and performing
+    reductions only on the necessary process groups.
+
+    Overlaps communication and computation if possible.
+
+    Args:
+        parameter_groups: Dictionary grouping parameters by synchronization requirements,
+            typically created by `group_parameters_for_norm`.
+        max_norm: The maximum allowed norm of the gradients. If None, the function
+            calculates and returns the global norm without modifying the gradients.
+        norm_type: The type of the norm to calculate (e.g., 2.0 for L2 norm, inf for max norm).
+        pp_mesh: The device mesh representing the pipeline parallel dimension, needed
+            to reduce norms across pipeline stages.
+
+    Returns:
+        The calculated global gradient norm.
+    """
+
+    with record_function("Gradient Clipping"):
+        global_norm_pow = _get_global_norm_pow_pp(
+            parameter_groups=parameter_groups,
+            norm_type=norm_type,
+            pp_mesh=pp_mesh
+        )
+        if math.isinf(norm_type):
+            global_norm = global_norm_pow
+        else:
+            global_norm = global_norm_pow ** (1.0 / norm_type)
+
+        if max_norm:
+            _clip_grad_with_norm_(
+                parameter_groups,
+                max_norm=max_norm,
+                total_norm=global_norm
+            )
+
+        return global_norm
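
Note on the math above: _get_local_norm_pow deliberately returns the local norm raised to the power norm_type, because a p-norm composes across shards as the p-th root of a sum of p-th powers. Summing the powered local norms via all-reduce and taking a single root at the end in clip_grad_norm_distributed_ therefore reproduces the norm of the full, unsharded gradient; for norm_type = inf the power and root are skipped and a MAX reduction is used instead, which is why _reduce_op_from_norm_type special-cases infinity. A minimal single-process sketch of that identity in plain PyTorch (no d9d types, illustrative shapes only):

import torch

p = 2.0
full_grad = torch.randn(1024)
shards = torch.chunk(full_grad, 4)  # stand-ins for the per-rank local gradient shards

# what _get_local_norm_pow computes on each rank
local_norm_pows = [shard.norm(p) ** p for shard in shards]

# what the SUM all-reduce plus the final ** (1.0 / p) in clip_grad_norm_distributed_ recovers
global_norm = torch.stack(local_norm_pows).sum() ** (1.0 / p)

assert torch.allclose(global_norm, full_grad.norm(p))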
d9d/internals/grad_sync/__init__.py
@@ -0,0 +1,14 @@
+"""
+Gradient synchronization utilities.
+
+This package provides the infrastructure for manual gradient bucketing and
+asynchronous reduction, similar to DistributedDataParallel but exposed
+for internal framework usage with DTensors.
+"""
+
+
+from .synchronizer import GradientSynchronizer
+
+__all__ = [
+    "GradientSynchronizer"
+]
d9d/internals/grad_sync/bucket.py
@@ -0,0 +1,317 @@
+import abc
+from typing import cast
+
+import torch
+import torch.distributed as dist
+from torch import Tensor, nn
+from torch.autograd.profiler import record_function
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor import DTensor
+from torch.utils.hooks import RemovableHandle
+
+from .placement_helper import dist_grad_from_local
+
+
+class AbstractGradientBucket(abc.ABC):
+    """
+    Interface for a bucket containing a subset of model parameters.
+
+    A bucket manages the memory layout and synchronization lifecycle of the
+    gradients associated with its parameters.
+    """
+
+    @abc.abstractmethod
+    def bind(self):
+        """
+        Initializes the bucket state.
+
+        This involves allocating contiguous memory buffers (if applicable),
+        registering backward hooks, and preparing the gradients for accumulation.
+        """
+
+    @abc.abstractmethod
+    def unbind(self):
+        """
+        Cleans up the bucket state.
+
+        Removes hooks, deallocates buffers, and detaches gradients.
+        """
+
+    @abc.abstractmethod
+    def zero_grad(self):
+        """
+        Zeros out the gradients and resets accumulation counters.
+        """
+
+    @abc.abstractmethod
+    def mark_sync(self):
+        """
+        Marks this bucket as synchronized.
+        """
+
+
+class LocalGradientBucket(AbstractGradientBucket):
+    """
+    A bucket for parameters that do not require distributed synchronization.
+    """
+
+    def __init__(self, params: list[nn.Parameter]):
+        """
+        Constructs a LocalGradientBucket.
+
+        Args:
+            params: List of parameters to manage.
+        """
+
+        self._params = params
+
+    def bind(self):
+        """
+        No-op for local buckets as they do not require special buffering.
+        """
+
+    def unbind(self):
+        """
+        No-op for local buckets.
+        """
+
+    def wait(self):
+        """
+        No-op as no async communication is performed.
+        """
+
+    @torch.no_grad()
+    def zero_grad(self):
+        """
+        Directly zeros the grad attribute of the parameters.
+        """
+
+        for param in self._params:
+            param.grad = None
+
+    def mark_sync(self):
+        """
+        No-op for local buckets.
+        """
+
+
+class AccumulationCounter:
+    """
+    Tracks the number of gradient accumulation steps for a set of parameters.
+    """
+
+    def __init__(self, require_accumulations: int, parameters: list[nn.Parameter]):
+        """
+        Constructs an AccumulationCounter.
+
+        Args:
+            require_accumulations: Number of accumulations required before sync.
+            parameters: List of parameters to track.
+        """
+
+        self._require_accumulations = require_accumulations
+        self._param_to_sync_count = {param: 0 for param in parameters}
+
+    def reset(self):
+        """
+        Resets all counters to zero.
+        """
+
+        self._param_to_sync_count = {param: 0 for param in self._param_to_sync_count}
+
+    def update(self, param: nn.Parameter):
+        """
+        Increments the counter for a specific parameter.
+
+        Args:
+            param: The parameter that finished a backward step.
+        """
+
+        self._param_to_sync_count[param] += 1
+
+    def is_ready(self) -> bool:
+        """
+        Checks if all parameters have reached the required number of accumulations.
+
+        Returns:
+            True if synchronization can proceed.
+        """
+
+        return all(x == self._require_accumulations for x in self._param_to_sync_count.values())
+
+
+class SyncGradientBucket(AbstractGradientBucket):
+    """
+    A bucket that manages a contiguous memory buffer for gradients and performs async reduction.
+
+    This bucket flattens the gradients of its parameters into a single contiguous
+    Tensor to enable efficient batched all-reduce operations.
+    """
+
+    def __init__(
+        self,
+        parameters: list[nn.Parameter],
+        require_accumulations: int,
+        device: torch.device,
+        grad_dtype: torch.dtype,
+        reduce_mesh: DeviceMesh,
+        communicate_stream: torch.cuda.Stream
+    ):
+        """
+        Constructs a SyncGradientBucket.
+
+        Args:
+            parameters: List of parameters to manage.
+            require_accumulations: Number of accumulations before triggering reduce.
+            device: Device where parameters reside.
+            grad_dtype: Data type for the gradients.
+            reduce_mesh: DeviceMesh on which reduction happens.
+            communicate_stream: Stream where all the asynchronous communications will be scheduled.
+        """
+
+        if not all(isinstance(x.data, DTensor) for x in parameters):
+            raise ValueError("All parameters passed in synchronizable bucket should contain DTensor data")
+
+        self._params = parameters
+        self._accum_counter = AccumulationCounter(require_accumulations, parameters)
+        self._device = device
+        self._grad_dtype = grad_dtype
+        # iterate from innermost to outermost group
+        self._reduce_groups: list[dist.ProcessGroup] = reduce_mesh.get_all_groups()[::-1]
+
+        self._buffer: Tensor | None = None
+        self._hooks: list[RemovableHandle] | None = None
+
+        self._communicate_stream = communicate_stream
+        self._ready_to_sync = False
+
+    def _bind_buffer(self):
+        """
+        Allocates the flat buffer and redirects parameter gradients to view into it.
+        """
+
+        buffer_size = sum(cast(DTensor, param.data).to_local().numel() for param in self._params)
+
+        self._buffer = torch.zeros(
+            (buffer_size,),
+            dtype=self._grad_dtype,
+            device=self._device
+        )
+
+        offset = 0
+
+        for param in self._params:
+            data = cast(DTensor, param.data)
+            local_param = data.to_local()
+
+            local_grad = self._buffer[offset:offset + local_param.numel()].view(local_param.shape)
+
+            param.grad = dist_grad_from_local(data, local_grad)
+
+            offset += local_param.numel()
+
+    @torch.no_grad()
+    def _post_accumulation_hook(self, param: nn.Parameter):
+        """
+        Hook executed after backward pass for a parameter.
+
+        Updates the accumulation counter and triggers the asynchronous all-reduce
+        if the bucket is ready.
+
+        Args:
+            param: The parameter that finished its backward pass.
+        """
+
+        self._accum_counter.update(param)
+
+        if not self._accum_counter.is_ready():
+            return
+
+        if self._ready_to_sync:
+            raise ValueError("Tried to accumulate, but synchronization was not performed")
+
+        with record_function("Gradient Sync"):
+            # wait for the backward operation to complete
+            self._communicate_stream.wait_stream(torch.cuda.current_stream())
+            # execute all sync operations in sequential order (to ensure
+            # data safety), but in a DIFFERENT stream
+            with torch.cuda.stream(self._communicate_stream):
+                for group in self._reduce_groups:
+                    dist.all_reduce(
+                        self._buffer,
+                        op=dist.ReduceOp.SUM,
+                        group=group
+                    )
+            self._ready_to_sync = True
+
+    def _bind_hooks(self):
+        """
+        Registers post-accumulate hooks on all parameters.
+        """
+
+        hooks = []
+        for param in self._params:
+            hooks.append(param.register_post_accumulate_grad_hook(self._post_accumulation_hook))
+        self._hooks = hooks
+
+    @torch.no_grad()
+    def bind(self):
+        """
+        Allocates the contiguous buffer and registers hooks.
+        """
+
+        self._bind_buffer()
+        self._bind_hooks()
+
+    def _unbind_buffer(self):
+        """
+        Deallocates the buffer and clears parameter gradients.
+        """
+
+        self._buffer = None
+
+        for param in self._params:
+            param.grad = None
+
+    def _unbind_hooks(self):
+        """
+        Removes all registered hooks.
+        """
+
+        if self._hooks is None:
+            return
+
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks = None
+
+    @torch.no_grad()
+    def unbind(self):
+        """
+        Cleans up buffer and hooks.
+        """
+
+        self._unbind_buffer()
+        self._unbind_hooks()
+
+    @torch.no_grad()
+    def zero_grad(self):
+        """
+        Zeros the contiguous buffer, resets counters, and marks params as awaiting sync.
+
+        Raises:
+            ValueError: If the buffer is not initialized (call bind first).
+        """
+
+        buffer = self._buffer
+        if buffer is None:
+            raise ValueError("Buffer is not initialized")
+
+        buffer.zero_()
+        self._accum_counter.reset()
+
+    def mark_sync(self):
+        if not self._ready_to_sync:
+            raise ValueError("This bucket is not ready for sync.")
+
+        self._ready_to_sync = False
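
For orientation, the following is a rough lifecycle sketch for SyncGradientBucket, written against only the methods shown above. It is not code from the package (GradientSynchronizer in synchronizer.py, not shown in this excerpt, presumably orchestrates the buckets), and the names model, mesh, micro_batches, loss_fn and optimizer are hypothetical stand-ins that assume an initialized process group and DTensor parameters:

# Hypothetical usage sketch; model/mesh/micro_batches/loss_fn/optimizer are illustrative only.
comm_stream = torch.cuda.Stream()
bucket = SyncGradientBucket(
    parameters=list(model.parameters()),       # every parameter must hold DTensor data
    require_accumulations=len(micro_batches),  # the hook fires once per parameter per backward
    device=torch.device("cuda"),
    grad_dtype=torch.float32,
    reduce_mesh=mesh,
    communicate_stream=comm_stream
)

bucket.bind()       # allocate the flat buffer, alias param.grad into it, register hooks
bucket.zero_grad()  # reset the buffer and the accumulation counters

for batch in micro_batches:
    loss_fn(model(batch)).backward()  # the final accumulation launches the async all-reduce

torch.cuda.current_stream().wait_stream(comm_stream)  # make the reduced grads visible
optimizer.step()
bucket.mark_sync()  # acknowledge the finished reduction before the next round
bucket.zero_grad()

The ordering matters: _post_accumulation_hook raises if a new accumulation round completes while _ready_to_sync is still set, so mark_sync() has to be called once per finished reduction before gradients are accumulated again.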
d9d/internals/grad_sync/placement_helper.py
@@ -0,0 +1,23 @@
+from torch import Tensor
+from torch.distributed.tensor import DTensor
+
+
+def dist_grad_from_local(data: DTensor, local_grad: Tensor) -> DTensor:
+    """
+    Constructs a DTensor gradient from a local tensor using data placement info.
+
+    Args:
+        data: The original parameter DTensor (source of metadata).
+        local_grad: The local tensor containing gradient data.
+
+    Returns:
+        A new DTensor wrapping the local gradient.
+    """
+
+    return DTensor.from_local(
+        local_grad,
+        shape=data.shape,
+        stride=data.stride(),
+        device_mesh=data.device_mesh,
+        placements=data.placements
+    )
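
dist_grad_from_local exists because of the flat-buffer layout in SyncGradientBucket._bind_buffer: each parameter's gradient is a strided view into one contiguous buffer, and DTensor.from_local wraps that view without copying, so the gradient stays backed by the buffer that is later all-reduced while regaining its DTensor metadata (mesh, placements, global shape and stride). A small plain-tensor sketch of the aliasing property being relied on (hypothetical shapes, no DTensor or distributed setup):

import torch

buffer = torch.zeros(6)
grad_a = buffer[0:4].view(2, 2)  # view for a hypothetical (2, 2) local shard
grad_b = buffer[4:6].view(2)     # view for a hypothetical (2,) local shard

grad_a.add_(1.0)  # writes into the per-parameter gradient views...
grad_b.add_(2.0)

print(buffer)  # tensor([1., 1., 1., 1., 2., 2.]): every write landed in the single contiguous buffer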