d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/core/types/data.py
ADDED
@@ -0,0 +1,14 @@
+from collections.abc import Callable, Sequence
+from typing import TypeAlias, TypeVar
+
+from .pytree import PyTree
+
+TDataTree = TypeVar("TDataTree", bound=PyTree)
+
+CollateFn: TypeAlias = Callable[[Sequence[TDataTree]], TDataTree]
+"""
+Type alias for a function that collates a sequence of samples into a batch.
+
+The function receives a sequence of individual data point structures (PyTrees)
+and is responsible for stacking or merging them into a single batched structure.
+"""
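For illustration, a minimal collate function matching this alias could look as follows. This is a sketch only: it assumes the wheel is installed, and the "input_ids" field is a made-up example, not part of the package.

import torch

from d9d.core.types.data import CollateFn
from d9d.dataset import pad_stack_1d


def collate_input_ids(samples):
    # Pad the variable-length "input_ids" of every sample and stack them into one batch tensor.
    return {"input_ids": pad_stack_1d([s["input_ids"] for s in samples], pad_value=0)}


# The alias only names the shape of such callables; any function with this signature qualifies.
collate: CollateFn = collate_input_ids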
d9d/core/types/pytree.py
ADDED
@@ -0,0 +1,26 @@
+from typing import TypeAlias, TypeVar
+
+import torch
+
+TLeaf = TypeVar("TLeaf")
+
+PyTree: TypeAlias = TLeaf | list["PyTree[TLeaf]"] | dict[str, "PyTree[TLeaf]"] | tuple["PyTree[TLeaf]", ...]
+"""
+A recursive type definition representing a tree of data.
+
+This type alias covers standard Python containers (dictionaries, lists, tuples)
+nested arbitrarily deep, terminating in a leaf node of type `TLeaf`.
+
+This is commonly used for handling nested state dictionaries or arguments
+passed to functions that support recursive traversal (similar to `torch.utils._pytree`).
+"""
+
+TensorTree: TypeAlias = PyTree[torch.Tensor]
+"""
+A recursive tree structure where the leaf nodes are PyTorch Tensors.
+"""
+
+ScalarTree: TypeAlias = PyTree[str | float | int | bool]
+"""
+A recursive tree structure where the leaf nodes are python scalars (str, float, int).
+"""
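These aliases are purely for typing; at runtime the trees are ordinary nested containers, so they can be traversed with torch's pytree utilities. A small sketch follows (torch.utils._pytree is a private torch module, referenced here only because the docstring above mentions it):

import torch
from torch.utils._pytree import tree_map

from d9d.core.types.pytree import TensorTree

batch: TensorTree = {"tokens": torch.ones(2, 3), "aux": [torch.zeros(2), torch.zeros(2)]}
# tree_map visits every tensor leaf regardless of nesting depth.
halved = tree_map(lambda t: t * 0.5, batch)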
d9d/dataset/__init__.py
ADDED
@@ -0,0 +1,17 @@
+"""
+This package provides utilities and torch.utils.data.Dataset implementations.
+"""
+
+from .buffer_sorted import BufferSortedDataset, DatasetImplementingSortKeyProtocol
+from .padding import PaddingSide1D, pad_stack_1d
+from .sharded import ShardedDataset, ShardIndexingMode, shard_dataset_data_parallel
+
+__all__ = [
+    "BufferSortedDataset",
+    "DatasetImplementingSortKeyProtocol",
+    "PaddingSide1D",
+    "ShardIndexingMode",
+    "ShardedDataset",
+    "pad_stack_1d",
+    "shard_dataset_data_parallel"
+]
d9d/dataset/buffer_sorted.py
ADDED
@@ -0,0 +1,143 @@
+import pickle  # noqa: S403
+import random
+from typing import Any, Protocol, TypeVar
+
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.utils.data import Dataset
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+class DatasetImplementingSortKeyProtocol(Protocol[_T_co]):
+    """
+    Protocol for datasets that support retrieval of a specific key for sorting purposes.
+
+    This is typically used for length-based bucketing/sorting where the dataset
+    needs to expose the length of an item without loading the full item.
+    """
+
+    def __len__(self) -> int:
+        """Returns the total number of items in the dataset."""
+        ...
+
+    def sort_key(self, index: int) -> Any:
+        """
+        Returns a value used for sorting the dataset at the given index.
+
+        Args:
+            index: The index of the item.
+
+        Returns:
+            A comparable value (e.g., int length) used for sorting.
+        """
+        ...
+
+    def __getitem__(self, item: int) -> _T_co:
+        """Retrieves the item at the specific index."""
+        ...
+
+
+class BufferSortedDataset(Dataset[_T_co], Stateful):
+    """
+    A dataset wrapper that groups items into buffers, sorts them, and yields them with local shuffling.
+
+    This prevents extreme padding in variable-length training (by grouping similar lengths)
+    while maintaining enough randomness to ensure statistical variance in updates.
+
+    Algorithm:
+
+    1. Select a range of indices (size `buffer_size`).
+    2. Sort these indices based on `base_dataset.sort_key()`.
+    3. Break the sorted list into packs of size `pack_size`.
+    4. Shuffle the order of these packs.
+    5. Flatten the list and serve items.
+    """
+
+    def __init__(
+        self,
+        base_dataset: DatasetImplementingSortKeyProtocol[_T_co],
+        buffer_size: int,
+        pack_size: int,
+        init_seed: int | None = None
+    ):
+        """
+        Constructs a BufferSortedDataset object.
+
+        Args:
+            base_dataset: The underlying dataset implementing the `DatasetImplementingSortKeyProtocol` protocol.
+            buffer_size: The number of items to load into the buffer for sorting.
+            pack_size: The size of local groups (batches/micro-batches) that remain
+                contiguous after sorting, but are shuffled relative to other packs.
+            init_seed: Seed for the random number generator used for shuffling packs.
+        """
+
+        self._base_dataset = base_dataset
+        self._buffer_size = buffer_size
+        self._pack_size = pack_size
+
+        self._rng = random.Random(init_seed ^ 0x105E7 if init_seed is not None else None)
+        self._buffer_indices: list[int] = []
+        self._buffer_idx: int = -1
+
+    def _update_buffer_idx(self, buffer_idx: int):
+        select_start = buffer_idx * self._buffer_size
+        select_end = (buffer_idx + 1) * self._buffer_size
+        select_end = min(select_end, len(self._base_dataset))
+
+        base_idx = list(range(select_start, select_end))
+        base_sort_keys = [self._base_dataset.sort_key(idx) for idx in range(select_start, select_end)]
+
+        local_idx = list(range(len(base_idx)))
+        local_idx = sorted(local_idx, key=lambda local_id: base_sort_keys[local_id])
+
+        local_idx_batch = [
+            local_idx[i: i + self._pack_size]
+            for i in range(0, len(local_idx), self._pack_size)
+        ]
+        self._rng.shuffle(local_idx_batch)
+        local_idx = [y for x in local_idx_batch for y in x]
+
+        self._buffer_indices = [base_idx[local_id] for local_id in local_idx]
+
+        self._buffer_idx = buffer_idx
+
+    def __getitem__(self, index: int) -> _T_co:
+        """
+        Retrieves an item from the locally sorted/shuffled buffer.
+
+        Args:
+            index: The global index.
+
+        Returns:
+            The dataset item.
+        """
+
+        needs_buffer_idx = index // self._buffer_size
+        if self._buffer_idx != needs_buffer_idx:
+            self._update_buffer_idx(needs_buffer_idx)
+
+        take_id = self._buffer_indices[index % self._buffer_size]
+
+        return self._base_dataset[take_id]
+
+    def __len__(self) -> int:
+        """Returns the length of the base dataset."""
+
+        return len(self._base_dataset)
+
+    def state_dict(self) -> dict[str, Any]:
+        ret = {
+            "seed": pickle.dumps(self._rng.getstate()),
+            "buffer_idx": self._buffer_idx,
+            "buffer_indices": self._buffer_indices,
+        }
+        if isinstance(self._base_dataset, Stateful):
+            ret["base_dataset"] = self._base_dataset.state_dict()
+        return ret
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self._rng.setstate(pickle.loads(state_dict["seed"]))  # noqa: S301
+        self._buffer_idx = state_dict["buffer_idx"]
+        self._buffer_indices = state_dict["buffer_indices"]
+        if isinstance(self._base_dataset, Stateful):
+            self._base_dataset.load_state_dict(state_dict["base_dataset"])
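A hedged usage sketch of the buffer-sorting behaviour. ToyVarLenDataset is invented for this example; it simply exposes per-item lengths through sort_key, as the protocol above requires.

import torch
from torch.utils.data import Dataset

from d9d.dataset import BufferSortedDataset


class ToyVarLenDataset(Dataset):
    """Hypothetical dataset whose items are 1D tensors of varying length."""

    def __init__(self, lengths):
        self._lengths = lengths

    def __len__(self):
        return len(self._lengths)

    def sort_key(self, index):
        # Cheap length lookup; no need to materialise the item.
        return self._lengths[index]

    def __getitem__(self, index):
        return torch.zeros(self._lengths[index], dtype=torch.long)


ds = BufferSortedDataset(ToyVarLenDataset([5, 50, 7, 48, 6, 51]), buffer_size=6, pack_size=2, init_seed=0)
# Items come out in packs of similar length (5/6, 7/48, 50/51), with the pack order shuffled.
print([len(ds[i]) for i in range(len(ds))])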
d9d/dataset/padding.py
ADDED
@@ -0,0 +1,79 @@
+from collections.abc import Sequence
+from enum import StrEnum
+
+import torch
+import torch.nn.functional as F
+
+
+class PaddingSide1D(StrEnum):
+    """
+    Enum specifying the side for padding 1D sequences.
+
+    Attributes:
+        left: Pad on the left side.
+        right: Pad on the right side.
+    """
+
+    left = "left"
+    right = "right"
+
+
+def _padding_side_1d_to_config(side: PaddingSide1D, difference: int) -> tuple[int, ...]:
+    match side:
+        case PaddingSide1D.left:
+            return difference, 0
+        case PaddingSide1D.right:
+            return 0, difference
+        case _:
+            raise ValueError("Unknown padding side")
+
+
+def pad_stack_1d(
+    items: Sequence[torch.Tensor],
+    pad_value: int,
+    padding_side: PaddingSide1D = PaddingSide1D.right,
+    pad_to_multiple_of: int | None = None
+) -> torch.Tensor:
+    """
+    Stacks 1D tensors into a batch, applying padding.
+
+    Calculates the maximum length among the input tensors (optionally aligning to a multiple),
+    pads elements to match this length on the specified side, and stacks them.
+
+    Args:
+        items: A sequence of 1D tensors to be stacked.
+        pad_value: The value used for padding.
+        padding_side: The side on which to apply padding (left or right).
+        pad_to_multiple_of: Optional integer. If provided, ensures the target length
+            is a multiple of this value.
+
+    Returns:
+        A single stacked tensor of shape (batch, max_length).
+
+    Raises:
+        ValueError: If no items are provided or if `pad_to_multiple_of` is <= 0.
+    """
+
+    if not items:
+        raise ValueError("Cannot stack 0 items")
+    if pad_to_multiple_of is not None and pad_to_multiple_of <= 0:
+        raise ValueError("pad_to_multiple_of should be > 0")
+
+    max_len = max(x.shape[0] for x in items)
+
+    if pad_to_multiple_of is not None and (remainder := max_len % pad_to_multiple_of) != 0:
+        max_len = max_len + (pad_to_multiple_of - remainder)
+
+    padded_items = []
+
+    for x in items:
+        difference = max_len - x.shape[0]
+
+        if difference == 0:
+            padded_items.append(x)
+        else:
+            padded_items.append(
+                F.pad(x, _padding_side_1d_to_config(padding_side, difference), value=pad_value)
+            )
+
+    return torch.stack(padded_items, dim=0)
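A quick sketch of pad_stack_1d on two ragged sequences, assuming the package is installed:

import torch

from d9d.dataset import PaddingSide1D, pad_stack_1d

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
# Max length 3 is rounded up to 4 (pad_to_multiple_of), and padding is added on the left.
batch = pad_stack_1d(seqs, pad_value=0, padding_side=PaddingSide1D.left, pad_to_multiple_of=4)
# batch == tensor([[0, 1, 2, 3],
#                  [0, 0, 4, 5]])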
d9d/dataset/sharded.py
ADDED
@@ -0,0 +1,195 @@
+import math
+from collections.abc import Sized
+from enum import StrEnum
+from typing import Any, TypeVar
+
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.utils.data import Dataset
+
+from d9d.core.dist_context import BATCH_DOMAIN, DistributedContext
+
+
+class ShardIndexingMode(StrEnum):
+    """
+    Defines how the dataset is split across shards.
+
+    Modes:
+        sequential: Round-robin distribution.
+
+            shard0: 0, 4, 8, 12
+            shard1: 1, 5, 9, 13
+            shard2: 2, 6, 10
+            shard3: 3, 7, 11
+
+        chunked: Contiguous blocks.
+
+            shard0: 0, 1, 2, 3
+            shard1: 4, 5, 6, 7
+            shard2: 8, 9, 10, 11
+            shard3: 12, 13
+    """
+
+    sequential = "sequential"
+    chunked = "chunked"
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+class ShardedDataset(Dataset[_T_co], Stateful):
+    """
+    A dataset wrapper that acts as a view onto a specific shard of the underlying dataset.
+
+    This is useful for Data Parallel training where each process should only see
+    a subset of the data. It supports different indexing modes and optional padding
+    to ensure all shards have equal length (preventing hangs in distributed collectives).
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset[_T_co],
+        total_shards: int,
+        current_shard: int,
+        indexing_mode: ShardIndexingMode,
+        pad_to_equal_size_across_shards: bool
+    ):
+        """
+        Constructs a ShardedDataset object.
+
+        Args:
+            dataset: The underlying dataset to shard.
+            total_shards: The total number of shards (e.g., number of DP ranks).
+            current_shard: The index of the current shard (e.g., current DP rank).
+            indexing_mode: How indices are assigned to shards (sequential/round-robin or chunked).
+            pad_to_equal_size_across_shards: If True, the length of the dataset will be padded
+                so that all shards report the same length. The last standard element is repeated.
+        """
+
+        if not isinstance(dataset, Sized):
+            raise ValueError("Dataset should implement __len__ method")
+
+        self._dataset = dataset
+
+        self._total_shards = total_shards
+        self._current_shard = current_shard
+
+        self._indexing_mode = indexing_mode
+        self._pad_to_equal_size_across_shards = pad_to_equal_size_across_shards
+
+    def _compute_real_index_sequential(self, index: int) -> int:
+        return index * self._total_shards + self._current_shard
+
+    def _get_base_index_unsafe(self, index: int) -> int:
+        """
+        Calculates the underlying dataset index for a given shard index,
+        without boundary checking.
+        """
+
+        match self._indexing_mode:
+            case ShardIndexingMode.sequential:
+                base_index = index * self._total_shards + self._current_shard
+
+                return base_index
+            case ShardIndexingMode.chunked:
+                ceil_len = math.ceil(len(self._dataset) / self._total_shards)
+                shard_start_offset = ceil_len * self._current_shard
+
+                return shard_start_offset + index
+            case _:
+                raise ValueError(f"Unknown shard indexing mode: {self._indexing_mode}")
+
+    def __getitem__(self, index: int) -> _T_co:
+        """
+        Retrieves an item from the underlying dataset mapping logic shard index to physical index.
+
+        If padding is enabled and the index exceeds the valid data for this shard,
+        the last item in the dataset is returned.
+
+        Args:
+            index: The index relative to this shard.
+
+        Returns:
+            The data item.
+        """
+
+        base_index = self._get_base_index_unsafe(index)
+        if base_index >= len(self._dataset):
+            base_index = len(self._dataset) - 1
+        return self._dataset[base_index]
+
+    def __len__(self) -> int:
+        """
+        Returns the number of items in this specific shard.
+
+        If `pad_to_equal_size_across_shards` is True, this returns the ceiling
+        length (max length across all shards).
+        """
+
+        ceil_len = math.ceil(len(self._dataset) / self._total_shards)
+
+        if self._pad_to_equal_size_across_shards:
+            return ceil_len
+
+        shards_remainder = len(self._dataset) % self._total_shards
+        match self._indexing_mode:
+            case ShardIndexingMode.sequential:
+                shards_full = len(self._dataset) // self._total_shards
+                return shards_full + 1 if self._current_shard < shards_remainder else shards_full
+            case ShardIndexingMode.chunked:
+                is_shard_last = self._current_shard == self._total_shards - 1
+                if not is_shard_last or shards_remainder == 0:
+                    return ceil_len
+                else:
+                    return ceil_len - (self._total_shards - shards_remainder)
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        if isinstance(self._dataset, Stateful):
+            self._dataset.load_state_dict(state_dict["dataset"])
+
+        # check whether env mismatched
+        if state_dict["total_shards"] != self._total_shards:
+            raise ValueError("Shard count mismatch")
+        self._total_shards = state_dict["total_shards"]
+
+        self._current_shard = state_dict["current_shard"]
+
+    def state_dict(self) -> dict[str, Any]:
+        dct: dict[str, Any] = {
+            "total_shards": self._total_shards,
+            "current_shard": self._current_shard
+        }
+        if isinstance(self._dataset, Stateful):
+            dct["dataset"] = self._dataset.state_dict()
+        return dct
+
+
+def shard_dataset_data_parallel(
+    dataset: Dataset[_T_co],
+    dist_context: DistributedContext,
+    indexing_mode: ShardIndexingMode = ShardIndexingMode.sequential,
+    pad_to_equal_size_across_shards: bool = True
+) -> Dataset[_T_co]:
+    """
+    Wraps a dataset into a ShardedDataset based on the Data Parallel dimension of the distributed context.
+
+    This is a helper function to automatically determine the correct rank and world size
+    from the 'dp' (Data Parallel) mesh dimension within the batch domain DeviceMesh.
+
+    Args:
+        dataset: The source dataset to shard.
+        dist_context: The distributed context.
+        indexing_mode: The strategy for splitting data indices (sequential/round-robin or chunked).
+        pad_to_equal_size_across_shards: If True, ensures all shards have the same length by padding.
+
+    Returns:
+        A dataset instance representing the local shard.
+    """
+
+    dp_mesh = dist_context.mesh_for(BATCH_DOMAIN)["dp"]
+    return ShardedDataset(
+        dataset=dataset,
+        total_shards=dp_mesh.size(),
+        current_shard=dp_mesh.get_local_rank(),
+        indexing_mode=indexing_mode,
+        pad_to_equal_size_across_shards=pad_to_equal_size_across_shards
+    )
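ShardedDataset can also be exercised without any distributed setup by passing the shard geometry explicitly. The RangeDataset below is invented for this sketch and simply returns its index; the output reproduces the shard1 row from the docstring table above.

from torch.utils.data import Dataset

from d9d.dataset import ShardedDataset, ShardIndexingMode


class RangeDataset(Dataset):
    """Hypothetical 14-item dataset where item i is just the integer i."""

    def __len__(self):
        return 14

    def __getitem__(self, index):
        return index


shard1 = ShardedDataset(
    dataset=RangeDataset(),
    total_shards=4,
    current_shard=1,
    indexing_mode=ShardIndexingMode.sequential,
    pad_to_equal_size_across_shards=True,
)
print([shard1[i] for i in range(len(shard1))])  # [1, 5, 9, 13]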
File without changes
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
from typing import cast
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.distributed.tensor
|
|
7
|
+
|
|
8
|
+
from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def set_seeds(
|
|
12
|
+
dist_context: DistributedContext,
|
|
13
|
+
seed: int,
|
|
14
|
+
distinct_seed_mesh_dim: str = "pp",
|
|
15
|
+
) -> None:
|
|
16
|
+
"""
|
|
17
|
+
Sets random seeds for Python, NumPy, and PyTorch.
|
|
18
|
+
|
|
19
|
+
This function sets seeds deterministically based on the provided base seed and the
|
|
20
|
+
process's rank within a specific mesh dimension.
|
|
21
|
+
|
|
22
|
+
The seed is shifted by the rank in the `distinct_seed_mesh_dim` (e.g., Pipeline Parallel rank).
|
|
23
|
+
This ensures that processes in different pipeline stages operate with different random states,
|
|
24
|
+
while processes that should share randomness (like Expert Parallel peers) can be synchronized.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
dist_context: The distributed context.
|
|
28
|
+
seed: The base random seed.
|
|
29
|
+
distinct_seed_mesh_dim: The name of the mesh dimension along which seeds should
|
|
30
|
+
be distinct (e.g., 'pp' for pipeline parallelism). Ranks along other dimensions
|
|
31
|
+
will share the seed.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
# Mutate seed based on PP rank if distributed
|
|
35
|
+
if dist_context.mesh_params.is_distributed:
|
|
36
|
+
distinct_mesh = dist_context.mesh_for(REGULAR_DOMAIN)[distinct_seed_mesh_dim]
|
|
37
|
+
seed = (seed + distinct_mesh.get_local_rank()) % 2**64
|
|
38
|
+
|
|
39
|
+
dist_context.logger.info(f"Set seed {seed}")
|
|
40
|
+
|
|
41
|
+
torch.manual_seed(seed)
|
|
42
|
+
os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
|
|
43
|
+
random.seed(seed)
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
import numpy as np # noqa: PLC0415
|
|
47
|
+
np.random.seed(seed)
|
|
48
|
+
except ImportError:
|
|
49
|
+
pass
|
|
50
|
+
|
|
51
|
+
# Set DTensor seeding if distributed
|
|
52
|
+
if dist_context.mesh_params.is_distributed:
|
|
53
|
+
mesh_regular = dist_context.mesh_for(REGULAR_DOMAIN)
|
|
54
|
+
duplicate_seed_mesh_dim = tuple(
|
|
55
|
+
name
|
|
56
|
+
for name
|
|
57
|
+
in cast(list[str], mesh_regular.mesh_dim_names)
|
|
58
|
+
if name != distinct_seed_mesh_dim
|
|
59
|
+
)
|
|
60
|
+
duplicate_seed_mesh = mesh_regular[duplicate_seed_mesh_dim] if len(duplicate_seed_mesh_dim) != 0 else None
|
|
61
|
+
|
|
62
|
+
if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None:
|
|
63
|
+
torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh) # noqa: SLF001
|
|
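The rank-dependent part of the seeding reduces to a simple shift. The snippet below mirrors that arithmetic outside the library; pp_rank is a hypothetical stand-in for distinct_mesh.get_local_rank().

base_seed = 1234
pp_rank = 2  # hypothetical pipeline-parallel rank of this process
effective_seed = (base_seed + pp_rank) % 2**64  # same wrap-around as in set_seeds
assert effective_seed == 1236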
d9d/internals/grad_norm/group.py
ADDED
@@ -0,0 +1,87 @@
+import dataclasses
+from collections import defaultdict
+from collections.abc import Iterable
+from typing import Any
+
+import torch
+from torch import nn
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor import DTensor, Shard
+
+
+@dataclasses.dataclass(kw_only=True, frozen=True)
+class GradNormGroup:
+    """
+    Defines a group of parameters that share the same distributed properties.
+
+    This grouping is used to batch gradient norm reductions efficiently. Parameters
+    sharing the same device mesh shards can be reduced in a single communication collective.
+
+    Attributes:
+        shard_meshes: A tuple of device meshes where the parameters are sharded, or None if replicated/local.
+        device: The device where parameters reside.
+        grad_dtype: The data type of the gradients.
+    """
+
+    shard_meshes: tuple[DeviceMesh, ...] | None
+    device: torch.device
+    grad_dtype: torch.dtype | None
+
+
+ParametersForNorm = dict[GradNormGroup, list[nn.Parameter]]
+
+
+def _extract_shard_meshes(param: nn.Parameter) -> tuple[DeviceMesh, ...] | None:
+    data = param.data
+
+    if not isinstance(data, DTensor):
+        return None
+
+    mesh = data.device_mesh
+    mesh_dim_names = mesh.mesh_dim_names
+    if mesh_dim_names is None:
+        raise ValueError("Only named meshes are supported.")
+
+    shard_placement_dim_names: list[str] = []
+
+    for dim_i, placement in enumerate(data.placements):
+        if isinstance(placement, Shard):
+            shard_placement_dim_names.append(mesh_dim_names[dim_i])
+
+    if len(shard_placement_dim_names) == 0:
+        return None
+
+    return tuple(mesh[name] for name in shard_placement_dim_names)
+
+
+def _group_sort_key(item: tuple[GradNormGroup, list[nn.Parameter]]) -> Any:
+    # put items WITH shard_meshes on top so they are processed first so we benefit from comm-comp overlap
+    return item[0].shard_meshes is None
+
+
+def group_parameters_for_norm(parameters: Iterable[nn.Parameter]) -> ParametersForNorm:
+    """
+    Groups parameters based on their distributed tensor characteristics.
+
+    Groups parameters by their sharding meshes, device, and gradient data type.
+
+    Args:
+        parameters: The iterable of parameters to group.
+
+    Returns:
+        A dictionary mapping synchronization groups to lists of parameters.
+    """
+
+    grouped_params: ParametersForNorm = defaultdict(list)
+    for param in parameters:
+        if not param.requires_grad:
+            continue
+
+        group = GradNormGroup(
+            shard_meshes=_extract_shard_meshes(param),
+            grad_dtype=param.grad_dtype,
+            device=param.device
+        )
+        grouped_params[group].append(param)
+    # we are sure dict is ordered in python 3.11 so we can sort it...
+    return dict(sorted(grouped_params.items(), key=_group_sort_key))
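A sketch of the grouping on a plain (non-distributed) module. Note that grad_dtype is not a standard nn.Parameter attribute; it appears to be attached elsewhere in the library, so it is set by hand here purely so the example runs.

import torch
from torch import nn

from d9d.internals.grad_norm.group import group_parameters_for_norm

model = nn.Linear(4, 4)
for p in model.parameters():
    # Assumption: d9d attaches this attribute during gradient setup; emulate it for the sketch.
    p.grad_dtype = torch.float32

groups = group_parameters_for_norm(model.parameters())
for group, params in groups.items():
    # Both weight and bias are local tensors on the same device, so they land in one group
    # with shard_meshes=None.
    print(group.shard_meshes, group.device, group.grad_dtype, len(params))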