d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0

d9d/loop/component/data_loader_factory.py
@@ -0,0 +1,258 @@
from collections.abc import Callable, Iterable, Iterator
from typing import Any, Self, TypedDict, Unpack

import torch
import torch.utils._pytree as pytree  # noqa: PLC2701
from torch.utils.data import Dataset, Sampler
from torchdata.stateful_dataloader import StatefulDataLoader

from d9d.core.dist_context import BATCH_DOMAIN, DistributedContext
from d9d.core.types import CollateFn, PyTree
from d9d.loop.config import DataLoadingConfig
from d9d.loop.control import DatasetProvider, InitializeDatasetContext

from .batch_maths import BatchMaths


class DataLoaderKwargs(TypedDict, total=False):
    """
    Type definition for arguments accepted by the PyTorch DataLoader.
    """

    batch_size: int | None
    shuffle: bool | None
    sampler: Sampler | Iterable | None
    batch_sampler: Sampler[list] | Iterable[list] | None
    num_workers: int
    collate_fn: CollateFn
    pin_memory: bool
    drop_last: bool
    timeout: float
    worker_init_fn: Callable | None
    multiprocessing_context: Any
    generator: Any
    prefetch_factor: int | None
    persistent_workers: bool
    pin_memory_device: str


def _move_to_device(data: PyTree, device: torch.types.Device) -> PyTree:
    return pytree.tree_map(lambda x: x.to(device), data)


class IteratorBatchGroup(Iterator):
    """
    An iterator that groups items from a base iterator into sub-streams.

    This class is used for gradient accumulation, where
    a single optimizer step consumes multiple micro-batches (the group).

    It also moves the data to the specified device immediately upon access.
    """

    def __init__(
        self,
        base: Iterator,
        device: torch.types.Device,
        batch_group_size: int
    ):
        """
        Constructs an IteratorBatchGroup object.

        Args:
            base: The underlying data iterator (usually from a DataLoader).
            device: The target device to move tensors to.
            batch_group_size: The number of micro-batches to yield within one group.
        """

        self._base = base
        self._device = device

        self._batch_group_size = batch_group_size

        self._is_end = False

    def __next__(self) -> PyTree:
        """
        Advances the iterator.

        Returns:
            A generator that yields `batch_group_size` items (micro-batches),
            with each item already moved to the configured device.

        Raises:
            StopIteration: If the underlying iterator is exhausted.
        """

        if self._is_end:
            raise StopIteration()

        try:
            sample_item = next(self._base)
        except StopIteration:
            self._is_end = True
            raise StopIteration() from None

        def _iter_inside_group():
            yield _move_to_device(sample_item, self._device)

            for _ in range(self._batch_group_size - 1):
                try:
                    item = next(self._base)
                    yield _move_to_device(item, self._device)
                except StopIteration:
                    self._is_end = True
                    break

        return _iter_inside_group()

    def __iter__(self) -> Self:
        """Returns self."""

        return self


class StatefulDataLoaderDataParallelAware(StatefulDataLoader):
    """
    A stateful data loader that is aware of data parallel ranks.

    This loader extends the standard torchdata StatefulDataLoader to ensure
    that checkpoints are saved with rank-specific keys.

    It also wraps the iterator to support batch grouping for gradient accumulation and
    automatically transfers data to the bound device.
    """

    def __init__(
        self,
        dataset: Dataset,
        dp_rank: int,
        device: torch.types.Device,
        group_size: int,
        **kwargs: Unpack[DataLoaderKwargs]
    ):
        """
        Constructs a StatefulDataLoaderDataParallelAware object.

        Args:
            dataset: The dataset to load from.
            dp_rank: The Data Parallel rank of the current process (used for state checkpointing).
            device: The device to move data to.
            group_size: The number of batches to group together (e.g., for gradient accumulation).
            **kwargs: Standard arguments passed to the parent DataLoader.
        """

        super().__init__(dataset, **kwargs)
        self._dp_rank = dp_rank
        self._device = device
        self._group_size = group_size

    def state_dict(self) -> dict[str, Any]:
        return {
            f"dp_{self._dp_rank}": super().state_dict()
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        super().load_state_dict(state_dict[f"dp_{self._dp_rank}"])

    def __iter__(self) -> Iterator:
        return IteratorBatchGroup(
            super().__iter__(),
            device=self._device,
            batch_group_size=self._group_size
        )


class DataLoaderFactory:
    """
    Factory class for creating configured DataLoaders.

    This class centralizes the creation logic for training and inference
    data loaders, applying the data loading configuration.
    """

    def __init__(
        self,
        dist_context: DistributedContext,
        provider: DatasetProvider,
        config_data_loading: DataLoadingConfig,
        batch_maths: BatchMaths
    ):
        """
        Constructs a DataLoaderFactory object.

        Args:
            dist_context: The distributed context containing mesh and rank information.
            provider: The provider callable that initializes the dataset and collator.
            config_data_loading: Specific configuration for data loading.
            batch_maths: BatchMaths object.
        """

        self._dist_context = dist_context
        self._provider = provider

        self._config_data_loading = config_data_loading

        self._batch_maths = batch_maths

    def _build_dataloader(
        self,
        provider: DatasetProvider,
        batch_size: int,
        group_size: int,
        drop_last: bool
    ) -> StatefulDataLoader:
        result = provider(InitializeDatasetContext(
            dist_context=self._dist_context,
            batch_maths=self._batch_maths
        ))

        return StatefulDataLoaderDataParallelAware(
            result.dataset,
            collate_fn=result.collator,
            group_size=group_size,
            num_workers=self._config_data_loading.num_workers,
            persistent_workers=self._config_data_loading.persistent_workers,
            pin_memory=self._config_data_loading.pin_memory,
            batch_size=batch_size,
            dp_rank=self._dist_context.mesh_for(BATCH_DOMAIN)["dp"].size(),
            device="cuda",
            drop_last=drop_last
        )

    def build_dataloader_for_train_job(self) -> StatefulDataLoader:
        """
        Builds and returns a StatefulDataLoader configured for training.

        This loader is configured to drop the last incomplete batch and group
        batches according to the gradient accumulation settings defined in
        BatchMaths.

        Returns:
            A configured StatefulDataLoader instance.
        """

        return self._build_dataloader(
            self._provider,
            batch_size=self._batch_maths.data_loader_batch_size,
            group_size=self._batch_maths.num_microbatches_gradient_accumulation,
            drop_last=True
        )

    def build_dataloader_for_infer_job(self) -> StatefulDataLoader:
        """
        Builds and returns a StatefulDataLoader configured for inference.

        This loader processes batches one by one (group size of 1) and does
        not drop the last batch.

        Returns:
            A configured StatefulDataLoader instance.
        """

        return self._build_dataloader(
            self._provider,
            batch_size=self._batch_maths.data_loader_batch_size,
            group_size=1,
            drop_last=False
        )
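
The loader returned by build_dataloader_for_train_job yields one group of micro-batches per optimizer step. The following sketch shows how that grouped iterator is meant to be consumed; it is illustrative only and not part of the package, and the names factory, model, optimizer and compute_loss are assumed placeholders.

# Illustrative sketch only -- `factory`, `model`, `optimizer` and `compute_loss`
# are hypothetical placeholders, not names from the package.
loader = factory.build_dataloader_for_train_job()

for micro_batch_group in loader:
    # Each element of the outer iterator is itself a generator that yields up to
    # `group_size` micro-batches, each already moved to the configured device.
    for micro_batch in micro_batch_group:
        loss = compute_loss(model(micro_batch))
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()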

d9d/loop/component/garbage_collector.py
@@ -0,0 +1,94 @@
import gc
import time
from contextlib import AbstractContextManager
from types import TracebackType
from typing import Self

from d9d.core.dist_context import DistributedContext
from d9d.loop.config import GarbageCollectionConfig

from .stepper import Stepper


class ManualGarbageCollector(AbstractContextManager):
    """
    Manages efficient Python garbage collection during the training loop.

    This context manager disables automatic garbage collection upon entry to prevent
    unpredictable latency spikes during training steps. It allows performing
    manual collection at specific intervals (periodic) or specific points (forced).
    """

    def __init__(
        self,
        dist_ctx: DistributedContext,
        config: GarbageCollectionConfig,
        step: Stepper
    ):
        """
        Constructs the garbage collector manager.

        Args:
            dist_ctx: The distributed context.
            config: Configuration determining how often GC should run.
            step: Stepper instance used to track the current training step.
        """
        self._dist_ctx = dist_ctx
        self._config = config
        self._step = step

    def __enter__(self) -> Self:
        """
        Disables automatic garbage collection and performs an initial full collection.

        Returns:
            The calling instance.
        """

        gc.disable()
        self._collect(generation=2)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None, /
    ) -> None:
        """
        Re-enables automatic garbage collection and performs a final full collection.

        Args:
            exc_type: The type of the exception raised (if any).
            exc_value: The exception instance raised (if any).
            traceback: The traceback object (if any).
        """

        gc.enable()
        self._collect(generation=2)

    def collect_periodic(self):
        """
        Triggers garbage collection if the current step matches the configured period.

        This typically performs a faster (generation 1) collection rather than a full sweep.
        """

        if self._step.should_do_action(self._config.period_steps, enable_on_last_step_if_periodic=False):
            self._collect(generation=1)

    def collect_forced(self):
        """
        Forces a full garbage collection run regardless of the step count.

        This performs a generation 2 collection.
        """

        self._collect(generation=2)

    def _collect(self, generation: int):
        begin = time.monotonic()
        gc.collect(generation)
        end = time.monotonic()
        self._dist_ctx.logger.info(f"[GC] Garbage collection for generation {generation} took {end - begin}s")
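
A minimal sketch of how ManualGarbageCollector could wrap a training loop follows. It is not taken from the package; dist_ctx, gc_config, stepper, num_steps and train_one_step are assumed placeholders.

# Illustrative sketch only; all names except ManualGarbageCollector are placeholders.
collector = ManualGarbageCollector(dist_ctx=dist_ctx, config=gc_config, step=stepper)

with collector:                        # automatic GC disabled, one full collection on entry
    for _ in range(num_steps):
        train_one_step()
        collector.collect_periodic()   # cheap generation-1 sweep every `period_steps` steps
    collector.collect_forced()         # explicit full generation-2 sweep
# on exit: automatic GC is re-enabled and a final full collection runs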

d9d/loop/component/gradient_clipper.py
@@ -0,0 +1,89 @@
from contextlib import contextmanager

from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
from d9d.internals.grad_norm import ParametersForNorm, clip_grad_norm_distributed_, group_parameters_for_norm
from d9d.loop.config import GradientClippingConfig
from d9d.tracker import BaseTrackerRun

from .model_stage_factory import TrackedModules
from .stepper import Stepper


class GradientClipper:
    """
    Manages gradient clipping and logging of gradient norms in a distributed execution environment.
    """

    def __init__(
        self,
        dist_context: DistributedContext,
        tracked_modules: TrackedModules,
        config: GradientClippingConfig,
        stepper: Stepper
    ):
        """
        Constructs the gradient clipper.

        Args:
            dist_context: The distributed context.
            tracked_modules: Container of model modules whose parameters need clipping.
            config: Configuration defining max norm and logging frequency.
            stepper: Stepper instance used to track the current training step.
        """

        self._dist_context = dist_context
        self._tracked_modules = tracked_modules
        self._config = config
        self._stepper = stepper

        self._parameter_groups: ParametersForNorm | None = None

    def _all_parameters(self):
        for model in self._tracked_modules.modules:
            yield from model.parameters()

    @contextmanager
    def install(self):
        """
        Context manager that prepares and groups parameters for efficient norm calculation.

        It calculates necessary metadata (such as segregating shared parameters) to ensure
        correct global norm calculation across the pipeline parallel mesh.
        """

        self._parameter_groups = group_parameters_for_norm(self._all_parameters())
        yield
        self._parameter_groups = None

    def clip_and_log(self, run: BaseTrackerRun):
        """
        Clips gradients to the configured maximum norm and logs the total L2 norm.

        This method performs an in-place modification of parameter gradients if a
        maximum norm is configured. It calculates the global gradient norm across
        distributed ranks.

        Args:
            run: The tracker run instance used for logging the norm scalar.

        Raises:
            ValueError: If called outside the ``install`` context manager scope.
        """

        should_log = self._stepper.should_do_action(self._config.log_total_steps)

        if not self._config.max_norm and not should_log:
            return

        if self._parameter_groups is None:
            raise ValueError("Parameter groups are not configured")

        grad_norm = clip_grad_norm_distributed_(
            parameter_groups=self._parameter_groups,
            max_norm=self._config.max_norm,
            norm_type=2.0,
            pp_mesh=self._dist_context.mesh_for(REGULAR_DOMAIN)["pp"],
        )

        if should_log:
            run.scalar(name="l2_grad_norm_total", value=grad_norm.item())
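
A hedged sketch of the intended call order: parameters are grouped once inside install, and clip_and_log runs once per optimizer step after gradients are ready. The surrounding names (clipper, tracker_run, run_forward_backward, optimizer, num_steps) are placeholders, and the exact placement relative to the rest of the loop may differ in the package's own training code.

# Illustrative sketch only; placement within the real training loop is assumed.
with clipper.install():                    # group parameters once for norm computation
    for _ in range(num_steps):
        run_forward_backward()             # produce gradients for this step
        clipper.clip_and_log(tracker_run)  # clip to config.max_norm, log "l2_grad_norm_total"
        optimizer.step()
        optimizer.zero_grad()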

d9d/loop/component/gradient_manager.py
@@ -0,0 +1,149 @@
from contextlib import contextmanager

import torch
from torch.distributed.tensor import DTensor

from d9d.core.dist_context import DistributedContext
from d9d.internals.grad_sync import GradientSynchronizer
from d9d.loop.config import GradientManagerConfig
from d9d.metric.impl import WeightedMeanMetric

from .batch_maths import BatchMaths
from .model_stage_factory import TrackedModules


class GradientManager:
    """
    Manages the lifecycle of gradients during the training loop.

    This class handles gradient synchronization across ranks,
    gradient data type configuration, and loss scaling based on accumulated weights.
    It orchestrates the `GradientSynchronizer` and ensures gradients are correctly
    prepared before the optimizer step.
    """

    def __init__(
        self,
        dist_context: DistributedContext,
        tracked_modules: TrackedModules,
        batch_maths: BatchMaths,
        config: GradientManagerConfig
    ):
        """
        Constructs the GradientManager and initializes the internal synchronizer.

        Args:
            dist_context: The distributed context.
            tracked_modules: Container of model modules to manage gradients for.
            batch_maths: Calculation utility for batch sizes and accumulation steps.
            config: Configuration for gradient handling.
        """

        self._dist_context = dist_context
        self._tracked_modules = tracked_modules
        self._batch_maths = batch_maths
        self._config = config
        self._loss = WeightedMeanMetric()
        self._loss.to("cuda")

        self._grad_sync = GradientSynchronizer(
            [list(module.parameters()) for module in self._tracked_modules.modules],
            bucket_size_mb=self._config.bucket_size_mb,
            require_accumulations=self._batch_maths.num_backward_calls
        )
        self._grads_to_scale: list[torch.Tensor] | None = None

    def _setup_grad_dtype(self):
        if self._config.grad_dtype is None:
            return

        for mod in self._tracked_modules.modules:
            for param in mod.parameters():
                if param.requires_grad:
                    param.grad_dtype = getattr(torch, self._config.grad_dtype)

    def _bind_grads_to_scale(self):
        grads_to_scale: list[torch.Tensor] = []

        for mod in self._tracked_modules.modules:
            for param in mod.parameters():
                if param.grad is None:
                    continue
                grad = param.grad.to_local() if isinstance(param.grad, DTensor) else param.grad
                grads_to_scale.append(grad)

        self._grads_to_scale = grads_to_scale

    def _unbind_grads_to_scale(self):
        self._grads_to_scale = None

    def _scale_grads(self):
        scale_factor = 1.0 / self._loss.accumulated_weight
        torch._foreach_mul_(self._grads_to_scale, scale_factor)

    @contextmanager
    def install(self):
        """
        Context manager to activate gradient handling for a forward/backward pass.

        This sets up gradient dtypes, installs backward hooks for synchronization via
        the `GradientSynchronizer`, and binds gradients for later scaling. It acts
        as the boundary for the accumulation phase.
        """

        self._setup_grad_dtype()
        self._grad_sync.bind()
        self._bind_grads_to_scale()
        yield
        self._unbind_grads_to_scale()
        self._grad_sync.unbind()

    def add_loss_with_weight(self, loss: torch.Tensor, loss_weight: torch.Tensor):
        """
        Accumulates a loss value and its corresponding weight into the internal metric.

        Args:
            loss: The computed loss scalar.
            loss_weight: The weight associated with this loss.
        """

        self._loss.update(loss, loss_weight)

    def sync_and_scale(self):
        """
        Finalizes gradients to be ready for the optimizer step.

        This method performs the following operations:

        1. Waits for all gradient synchronization hooks to complete.
        2. Synchronizes the accumulated loss/weights across the distributed context.
        3. Scales the gradients by the inverse of the total accumulated weight to
           normalize them.
        """

        self._grad_sync.wait()

        self._loss.trigger_sync(self._dist_context)
        self._loss.wait_sync(self._dist_context)
        self._scale_grads()

    def compute_global_loss(self) -> torch.Tensor:
        """
        Calculates the final weighted mean loss.

        Returns:
            The averaged loss scalar across all accumulation steps and ranks.
        """

        return self._loss.compute()

    def zero_grad(self):
        """
        Resets the internal state for the next training step.

        This clears the accumulated gradients in the synchronizer and resets the
        loss metrics.
        """

        self._grad_sync.zero_grad()
        self._loss.reset()