d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/internals/grad_sync/synchronizer.py

@@ -0,0 +1,257 @@
+import dataclasses
+from collections import defaultdict
+from typing import cast
+
+import torch
+from torch import nn
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor import DTensor, Replicate, Shard
+
+from .bucket import AbstractGradientBucket, LocalGradientBucket, SyncGradientBucket
+
+
+def _find_reduce_mesh(data: DTensor) -> DeviceMesh | None:
+    """
+    Identifies the sub-mesh required for gradient reduction based on tensor placements.
+
+    Args:
+        data: The parameter tensor.
+
+    Returns:
+        The DeviceMesh subset needed for reduction, or None if no reduction is needed.
+    """
+
+    reduce_dims: set[int] = set()
+
+    for dim_i, dim_placement in enumerate(data.placements):
+        match dim_placement:
+            case Replicate():
+                reduce_dims.add(dim_i)
+            case Shard():
+                pass
+            case _:
+                raise ValueError(f"Unknown grad placement: {dim_placement}")
+
+    if len(reduce_dims) == 0:
+        return None
+
+    device_mesh: DeviceMesh = data.device_mesh
+
+    # we are sure that device mesh contain dim names so we cast(...)
+    mesh_dim_names = cast(tuple[str, ...], device_mesh.mesh_dim_names)
+    reduce_mesh = device_mesh[tuple(
+        mesh_dim_names[dim_i] for dim_i in reduce_dims
+    )]
+
+    return reduce_mesh
+
+
+@dataclasses.dataclass(frozen=True)
+class _ParameterGroupMarker:
+    """
+    Identifier for grouping compatible parameters into buckets.
+    """
+
+    group_i: int
+    reduce_mesh: DeviceMesh | None
+    device: torch.device
+    grad_dtype: torch.dtype | None
+
+
+def _group_params_for_buckets(
+    param_groups: list[list[nn.Parameter]]
+) -> dict[_ParameterGroupMarker, list[nn.Parameter]]:
+    """
+    Sorts parameters into groups based on their synchronization requirements.
+
+    Args:
+        param_groups: List of parameter groups (from optimizer).
+
+    Returns:
+        Dictionary mapping group markers to lists of parameters.
+    """
+
+    regrouped_params = defaultdict(list)
+    for param_group_i, param_group in enumerate(param_groups):
+        # iterate in reverse order to maximize overlap
+        for param in param_group[::-1]:
+            if not param.requires_grad:
+                continue
+
+            if not isinstance(param.data, DTensor):
+                raise TypeError("All params should be DTensors in a distributed setup")
+
+            reduce_mesh = _find_reduce_mesh(param.data)
+
+            group = _ParameterGroupMarker(
+                group_i=param_group_i,
+                reduce_mesh=reduce_mesh,
+                device=param.data.device,
+                grad_dtype=param.grad_dtype
+            )
+
+            regrouped_params[group].append(param)
+
+    return regrouped_params
+
+
+def _make_bucket(
+    require_accumulations: int,
+    group_marker: _ParameterGroupMarker,
+    parameters: list[nn.Parameter],
+    communicate_stream: torch.cuda.Stream
+) -> AbstractGradientBucket:
+    """
+    Factory function to create the appropriate bucket type.
+    """
+
+    if group_marker.reduce_mesh is None:
+        return LocalGradientBucket(parameters)
+    else:
+        if group_marker.grad_dtype is None:
+            raise ValueError("Gradient dtype could not be None")
+
+        return SyncGradientBucket(
+            parameters=parameters,
+            require_accumulations=require_accumulations,
+            device=group_marker.device,
+            grad_dtype=group_marker.grad_dtype,
+            reduce_mesh=group_marker.reduce_mesh,
+            communicate_stream=communicate_stream
+        )
+
+
+def _fill_buckets(
+    param_groups: dict[_ParameterGroupMarker, list[nn.Parameter]],
+    bucket_size_mb: int,
+    require_accumulations: int,
+    communicate_stream: torch.cuda.Stream
+) -> list[AbstractGradientBucket]:
+    """
+    Splits grouped parameters into buckets based on size constraints.
+
+    Args:
+        param_groups: Parameters grouped by sync requirements.
+        bucket_size_mb: Max size for each bucket in megabytes.
+        require_accumulations: Number of gradient accumulations required before syncing gradients.
+
+    Returns:
+        List of configured gradient buckets.
+    """
+
+    # TODO: Better grouping - probably we could trace autograd graph and use some topological clustering here
+    # TODO: to maximize overlap even better - current implementation just iterates over parameters in reverse order
+    buckets = []
+
+    bucket_size = bucket_size_mb * 1024 * 1024
+
+    for param_group_marker, param_group in param_groups.items():
+        current_bucket_size = 0
+        unfinished_bucket: list[nn.Parameter] = []
+        for param in param_group:
+            param_bytes = param.numel() * param.element_size()
+            if current_bucket_size + param_bytes >= bucket_size and unfinished_bucket:
+                buckets.append(_make_bucket(
+                    require_accumulations=require_accumulations,
+                    group_marker=param_group_marker,
+                    parameters=unfinished_bucket,
+                    communicate_stream=communicate_stream
+                ))
+                unfinished_bucket = []
+                current_bucket_size = 0
+
+            unfinished_bucket.append(param)
+            current_bucket_size += param_bytes
+
+        if unfinished_bucket:
+            buckets.append(_make_bucket(
+                require_accumulations=require_accumulations,
+                group_marker=param_group_marker,
+                parameters=unfinished_bucket,
+                communicate_stream=communicate_stream
+            ))
+    return buckets
+
+
+class GradientSynchronizer:
+    """
+    Manages gradient synchronization for distributed training.
+
+    This class handles the bucketing of parameters, memory allocation for flat
+    gradient buffers, and the orchestration of asynchronous all-reduce operations
+    during the backward pass.
+    """
+
+    def __init__(
+        self,
+        param_groups: list[list[nn.Parameter]],
+        bucket_size_mb: int,
+        require_accumulations: int
+    ):
+        """
+        Constructs a GradientSynchronizer.
+
+        Args:
+            param_groups: List of parameter groups.
+            bucket_size_mb: Maximal size of a single gradient bucket in MB.
+            require_accumulations: Number of micro-batches to accumulate before reducing.
+        """
+
+        self._param_groups = param_groups
+        self._bucket_size_mb = bucket_size_mb
+        self._require_accumulations = require_accumulations
+
+        self._communicate_stream: torch.cuda.Stream | None = None
+        self._can_sync: bool
+        self._buckets: list[AbstractGradientBucket] = []
+
+    def bind(self):
+        """
+        Initializes the synchronizer for training.
+
+        Groups parameters, creates buckets, allocates memory, and registers hooks.
+        Must be called before the backward pass.
+        """
+
+        stream = torch.cuda.Stream()
+        self._communicate_stream = stream
+        self._buckets = _fill_buckets(
+            _group_params_for_buckets(self._param_groups),
+            bucket_size_mb=self._bucket_size_mb,
+            require_accumulations=self._require_accumulations,
+            communicate_stream=stream
+        )
+
+        for bucket in self._buckets:
+            bucket.bind()
+
+    def unbind(self):
+        """
+        Releases resources.
+
+        Destroys buckets, frees memory buffers, and removes hooks.
+        """
+
+        for bucket in self._buckets:
+            bucket.unbind()
+
+        self._buckets = []
+        self._communicate_stream = None
+
+    def wait(self):
+        """
+        Waits for all bucket operations (async reductions) to complete.
+        """
+
+        torch.cuda.current_stream().wait_stream(self._communicate_stream)
+
+        for bucket in self._buckets:
+            bucket.mark_sync()
+
+    def zero_grad(self):
+        """
+        Resets gradients and accumulation counters for all managed parameters.
+        """
+
+        for bucket in self._buckets:
+            bucket.zero_grad()
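For orientation, the sketch below shows one way a GradientSynchronizer could be driven in a gradient-accumulation loop. It is a minimal illustration, not d9d's actual training loop (that lives under d9d/loop/): `model`, `optimizer`, `micro_batches` and `loss_fn` are hypothetical placeholders, and the parameters are assumed to already be DTensors on a device mesh, as _group_params_for_buckets requires.

# Hypothetical wiring of GradientSynchronizer; not the d9d training loop itself.
from d9d.internals.grad_sync.synchronizer import GradientSynchronizer


def train_with_grad_sync(model, optimizer, micro_batches, loss_fn, accum_steps=4):
    # One parameter group holding every trainable (DTensor) parameter.
    sync = GradientSynchronizer(
        param_groups=[[p for p in model.parameters() if p.requires_grad]],
        bucket_size_mb=25,
        require_accumulations=accum_steps,
    )
    sync.bind()  # group params, build buckets, register backward hooks

    for step_batches in micro_batches:  # each item: `accum_steps` micro-batches
        for micro_batch in step_batches:
            loss = loss_fn(model(micro_batch))
            loss.backward()  # buckets launch async reduces on the comm stream
        sync.wait()          # current stream waits for the communication stream
        optimizer.step()
        sync.zero_grad()     # reset grads and accumulation counters

    sync.unbind()            # free buffers and remove hooks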
d9d/internals/pipeline_state/__init__.py

@@ -0,0 +1,14 @@
+"""
+Pipeline State management package.
+
+This package provides mechanisms to store, retrieve, and synchronize state
+across different stages of a distributed pipeline, providing global and sharded view for these states.
+"""
+
+from .api import PipelineState
+from .handler import PipelineStateHandler
+
+__all__ = [
+    "PipelineState",
+    "PipelineStateHandler"
+]
d9d/internals/pipeline_state/api.py

@@ -0,0 +1,45 @@
+import abc
+from typing import Any
+
+
+class PipelineState(abc.ABC):
+    """
+    Object representing the state of a pipeline.
+
+    This class defines the interface for accessing state variables like a dictionary,
+    abstracting away whether the underlying storage is local, sharded, or global.
+    """
+
+    @abc.abstractmethod
+    def __setitem__(self, key: str, value: Any):
+        """
+        Sets a state value for a given key.
+
+        Args:
+            key: The identifier for the state variable.
+            value: The value to store.
+        """
+
+    @abc.abstractmethod
+    def __getitem__(self, item: str) -> Any:
+        """
+        Retrieves a state value for a given key.
+
+        Args:
+            item: The identifier for the state variable.
+
+        Returns:
+            The value associated with the key.
+        """
+
+    @abc.abstractmethod
+    def __contains__(self, item: str) -> bool:
+        """
+        Checks if a key exists in the state.
+
+        Args:
+            item: The identifier to check.
+
+        Returns:
+            True if the key exists, False otherwise.
+        """
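As a purely illustrative (hypothetical) implementation of this contract, a dict-backed subclass would look like the following; the package's real implementations in handler.py delegate to PipelineStateStorage instead.

from typing import Any

from d9d.internals.pipeline_state.api import PipelineState


class DictPipelineState(PipelineState):
    """Toy in-memory PipelineState, shown only to illustrate the interface."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def __setitem__(self, key: str, value: Any):
        self._data[key] = value

    def __getitem__(self, item: str) -> Any:
        return self._data[item]

    def __contains__(self, item: str) -> bool:
        return item in self._data


state: PipelineState = DictPipelineState()
state["epoch"] = 3
assert "epoch" in state and state["epoch"] == 3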
d9d/internals/pipeline_state/handler.py

@@ -0,0 +1,111 @@
+from typing import Any
+
+from d9d.core.sharding import ShardingSpecLeaf
+
+from .api import PipelineState
+from .storage import PipelineStateStorage
+
+
+class PipelineStateGlobal(PipelineState):
+    """
+    Represents the global (unsharded) view of the pipeline state.
+    """
+
+    def __init__(self, storage: PipelineStateStorage):
+        """
+        Constructs a PipelineStateGlobal object.
+
+        Args:
+            storage: The underlying storage backend.
+        """
+
+        self._storage = storage
+
+    def __setitem__(self, key: str, value: Any):
+        self._storage.store_global((key,), value)
+
+    def __getitem__(self, item: str) -> Any:
+        return self._storage.acquire_global((item,))
+
+    def __contains__(self, item: str) -> bool:
+        return self._storage.contains((item,))
+
+
+class PipelineStateShard(PipelineState):
+    """
+    Represents a sharded view of the pipeline state for a specific shard ID.
+    """
+
+    def __init__(self, storage: PipelineStateStorage, current_shard: int):
+        """
+        Constructs a PipelineStateShard object.
+
+        Args:
+            storage: The underlying storage backend.
+            current_shard: The index of the partial shard this view represents.
+        """
+
+        self._storage = storage
+        self._current_shard = current_shard
+
+    def __setitem__(self, key: str, value: Any):
+        self._storage.store_shard((key,), value, self._current_shard)
+
+    def __getitem__(self, item: str) -> Any:
+        return self._storage.acquire_shard((item,), self._current_shard)
+
+    def __contains__(self, item: str) -> bool:
+        return self._storage.contains((item,))
+
+
+class PipelineStateHandler:
+    """
+    Manages the lifecycle and access patterns of pipeline states.
+
+    This handler initializes the underlying storage and provides specific views
+    (global or sharded) into that storage.
+    """
+
+    def __init__(self, sharding_spec: dict[str, ShardingSpecLeaf], num_shards: int):
+        """
+        Constructs a PipelineStateHandler object.
+
+        Args:
+            sharding_spec: A definition of how specific keys should be sharded.
+            num_shards: The total number of shards in the pipeline.
+        """
+
+        self._storage = PipelineStateStorage(
+            sharding_spec={(k,): v for k, v in sharding_spec.items()},
+            num_shards=num_shards
+        )
+
+    def global_state(self) -> PipelineState:
+        """
+        Returns a view interface for accessing global state.
+
+        Returns:
+            A PipelineState interface that accesses the full, aggregated data.
+        """
+
+        return PipelineStateGlobal(self._storage)
+
+    def sharded_state(self, shard_id: int) -> PipelineState:
+        """
+        Returns a view interface for accessing state specific to a shard ID.
+
+        Args:
+            shard_id: The index of the shard to access.
+
+        Returns:
+            A PipelineState interface that accesses partial data for the given shard.
+        """
+
+        return PipelineStateShard(self._storage, shard_id)
+
+    def reset(self):
+        """
+        Resets the underlying storage, clearing all state.
+        """
+
+        self._storage.reset()
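To make the two views concrete, here is a small hedged sketch of a 2-shard handler: each shard writes its slice of a "loss" tensor through a sharded view, and the global view triggers aggregation on first read. SpecShard comes from d9d.core.sharding (as imported by storage.py); the aggregated shape shown in the comment assumes unshard_tree concatenates along dim 0 for SpecShard(dim=0).

import torch

from d9d.core.sharding import SpecShard
from d9d.internals.pipeline_state import PipelineStateHandler

handler = PipelineStateHandler(
    sharding_spec={"loss": SpecShard(dim=0)},  # "loss" is split along dim 0
    num_shards=2,
)

# Each pipeline shard writes only its own slice through a sharded view.
for shard_id in (0, 1):
    view = handler.sharded_state(shard_id)
    view["loss"] = torch.full((3,), float(shard_id))

# The global view re-assembles the slices on first access.
full = handler.global_state()["loss"]
print(full.shape)  # expected torch.Size([6]) if SpecShard(dim=0) concatenates

handler.reset()  # clear everything before the next pipeline run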
d9d/internals/pipeline_state/storage.py

@@ -0,0 +1,236 @@
+import copy
+from collections import UserDict
+from typing import Any, TypeVar, cast
+
+import torch
+import torch.utils._pytree as pytree # noqa: PLC2701
+
+from d9d.core.sharding import ShardingSpecLeaf, SpecReplicate, SpecShard, shard_tree, unshard_tree
+
+StateKey = tuple[str, ...]
+
+
+TMap = TypeVar("TMap")
+
+
+def _detach_leaf(x: TMap) -> TMap:
+    """
+    Detaches a tensor from the computation graph if the input is a tensor.
+
+    Args:
+        x: The input object.
+
+    Returns:
+        The detached tensor or original object.
+    """
+
+    if isinstance(x, torch.Tensor):
+        return cast(TMap, x.detach())
+    return x
+
+
+class ShardedState(UserDict):
+    """
+    Container for holding state broken down by shard indices.
+    """
+
+
+class PipelineStateStorage:
+    """
+    Low-level storage backend handling sharding and aggregation of state data.
+
+    This class manages the transition between sharded data
+    and global data. It uses sharding specifications to determine
+    how to split or join data.
+    """
+
+    def __init__(
+        self,
+        sharding_spec: dict[StateKey, ShardingSpecLeaf],
+        num_shards: int
+    ):
+        """
+        Constructs a PipelineStateStorage object.
+
+        Args:
+            sharding_spec: Dictionary mapping state keys to their sharding behaviors.
+            num_shards: Total number of shards involved in the storage.
+        """
+
+        self._sharding_spec_orig = copy.deepcopy(sharding_spec)
+
+        self._state: dict[StateKey, Any] = {}
+        self._state_sharding_spec: dict[StateKey, ShardingSpecLeaf] = {}
+
+        self._num_shards = num_shards
+
+    def _guess_sharding_spec_for_shard(self, key: StateKey, shard: Any) -> ShardingSpecLeaf:
+        # Stack if scalar (tensor or item), cat otherwise
+
+        if key in self._sharding_spec_orig:
+            return self._sharding_spec_orig[key]
+
+        if isinstance(shard, torch.Tensor):
+            do_stack = shard.ndim == 0
+            return SpecShard(dim=0, do_stack=do_stack)
+        elif isinstance(shard, list):
+            return SpecShard(dim=0)
+        else:
+            return SpecShard(dim=0, do_stack=True)
+
+    def _guess_sharding_spec_for_global(self, key: StateKey, state: Any) -> ShardingSpecLeaf:
+        if key in self._sharding_spec_orig:
+            return self._sharding_spec_orig[key]
+
+        if isinstance(state, torch.Tensor):
+            if state.ndim == 0:
+                return SpecReplicate()
+            else:
+                return SpecShard(dim=0)
+        elif isinstance(state, list):
+            return SpecShard(dim=0)
+        else:
+            return SpecReplicate()
+
+    def store_global(self, key: StateKey, state: Any):
+        """
+        Stores a value in the global scope.
+
+        If the key does not have a sharding spec, one will be inferred. Detaches tensors.
+
+        Args:
+            key: The identifier key.
+            state: The unified value to store.
+        """
+
+        state = pytree.tree_map(_detach_leaf, state)
+
+        if key not in self._state_sharding_spec:
+            self._state_sharding_spec[key] = self._guess_sharding_spec_for_global(key, state)
+
+        self._state[key] = state
+
+    def store_shard(self, key: StateKey, state: Any, shard_id: int):
+        """
+        Stores a value for a specific shard index.
+
+        Raises error if attempting to shard an already global key without conversion.
+
+        Args:
+            key: The identifier key.
+            state: The partial value for the shard.
+            shard_id: The index of the shard.
+
+        Raises:
+            ValueError: If trying to store sharded state into an unsharded container.
+        """
+
+        if key not in self._state:
+            self._state[key] = ShardedState()
+
+        container = self._state[key]
+
+        if not isinstance(container, ShardedState):
+            raise ValueError(f"Trying to store sharded state into an unsharded one: {key}")
+
+        state = pytree.tree_map(_detach_leaf, state)
+
+        # dynamically populate sharding spec to know whether it is stacking or not
+        if key not in self._state_sharding_spec:
+            self._state_sharding_spec[key] = self._guess_sharding_spec_for_shard(key, state)
+
+        self._state[key][shard_id] = state
+
+    def _ensure_global(self, key: StateKey):
+        if key not in self._state:
+            raise ValueError(f"Cannot access non-existing state {key}")
+
+        state = self._state[key]
+
+        if not isinstance(state, ShardedState):
+            return
+
+        # here we know we are in ShardedState
+
+        shards = [state[shard_id] for shard_id in range(self._num_shards)]
+        resharded = unshard_tree(shards, self._state_sharding_spec[key])
+
+        self._state[key] = resharded
+
+    def _ensure_sharded(self, key: StateKey):
+        if key not in self._state:
+            raise ValueError(f"Cannot access non-existing state {key}")
+
+        state = self._state[key]
+
+        if isinstance(state, ShardedState):
+            return
+
+        # here we know we are in global state
+
+        sharded = shard_tree(
+            state,
+            self._state_sharding_spec[key],
+            num_shards=self._num_shards,
+            enforce_even_split=True
+        )
+
+        sharded_state = ShardedState({
+            shard_idx: shard for shard_idx, shard in enumerate(sharded)
+        })
+
+        self._state[key] = sharded_state
+
+    def acquire_global(self, key: StateKey) -> Any:
+        """
+        Retrieves data for a key in its global form.
+
+        Args:
+            key: The state key.
+
+        Returns:
+            The aggregated global data.
+        """
+
+        self._ensure_global(key)
+        return self._state[key]
+
+    def acquire_shard(self, key: StateKey, shard: int) -> Any:
+        """
+        Retrieves data for a key specific to a shard index.
+
+        Args:
+            key: The state key.
+            shard: The shard index.
+
+        Returns:
+            The data slice corresponding to the shard.
+        """
+
+        self._ensure_sharded(key)
+        state = self._state[key]
+
+        if isinstance(state, ShardedState):
+            return state[shard]
+        else:
+            return state
+
+    def contains(self, key: StateKey) -> bool:
+        """
+        Checks if a key exists in storage.
+
+        Args:
+            key: The state key.
+
+        Returns:
+            True if present.
+        """
+
+        return key in self._state
+
+    def reset(self):
+        """
+        Clears all stored state.
+        """
+
+        self._state.clear()
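Finally, a short sketch of the spec-guessing behaviour above, exercising PipelineStateStorage directly with an empty spec so every key goes through _guess_sharding_spec_*; the stacked/split shapes noted in the comments are assumptions about shard_tree/unshard_tree in d9d.core.sharding.

import torch

from d9d.internals.pipeline_state.storage import PipelineStateStorage

storage = PipelineStateStorage(sharding_spec={}, num_shards=2)

# Per-shard scalar tensors: guessed spec is SpecShard(dim=0, do_stack=True),
# so the global view should stack them into a length-2 tensor.
storage.store_shard(("step_loss",), torch.tensor(0.5), shard_id=0)
storage.store_shard(("step_loss",), torch.tensor(0.7), shard_id=1)
print(storage.acquire_global(("step_loss",)))

# A global non-scalar tensor: guessed spec is SpecShard(dim=0), so each shard
# view should receive an even split along dim 0 (enforce_even_split=True).
storage.store_global(("token_ids",), torch.arange(8))
print(storage.acquire_shard(("token_ids",), shard=0))

storage.reset()  # drop all keys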