d9d 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from pydantic import BaseModel
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class ModelStateIndexMeta(BaseModel):
    """
    Global metadata section of a model state index.

    Attributes:
        total_size: Total size of the model parameters in bytes.
    """

    total_size: int
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ModelStateIndex(BaseModel):
    """
    In-memory form of the `model.safetensors.index.json` file.

    The index records, for every weight name, which .safetensors file stores that weight.

    Attributes:
        metadata: Global metadata about the checkpoint.
        weight_map: Mapping from parameter name to filename.
    """

    metadata: ModelStateIndexMeta
    weight_map: dict[str, str]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Name of the index file whose content is described by `ModelStateIndex`.
MODEL_STATE_INDEX_FILE_NAME = "model.safetensors.index.json"
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
from torch.distributed.tensor import DTensor
|
|
6
|
+
|
|
7
|
+
from d9d.model_state.mapper import ModelStateMapper
|
|
8
|
+
from d9d.model_state.mapper.compose import (
|
|
9
|
+
ModelStateMapperParallel,
|
|
10
|
+
ModelStateMapperSequential,
|
|
11
|
+
)
|
|
12
|
+
from d9d.model_state.mapper.leaf import (
|
|
13
|
+
ModelStateMapperDistribute,
|
|
14
|
+
ModelStateMapperIdentity,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from .reader import read_model_state
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _build_injection_mapper(name: str, state: torch.Tensor) -> ModelStateMapper:
    """
    Selects the leaf mapper used right before a loaded state is injected into the model.

    Args:
        name: Name of the state in the model's state dict.
        state: The tensor the model currently holds under `name`.

    Returns:
        A distributing mapper that reuses the placements and device mesh of `state` when it is
        a `DTensor`; an identity mapper for `name` otherwise.
    """
    if not isinstance(state, DTensor):
        return ModelStateMapperIdentity(name)

    return ModelStateMapperDistribute(name=name, placements=state.placements, device_mesh=state.device_mesh)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _augment_mapper_for_injection(model: nn.Module, mapper: ModelStateMapper) -> ModelStateMapper:
    """
    Appends an injection stage to `mapper`, derived from the model's current states.

    Every output name declared by `mapper` gets a leaf mapper chosen from the tensor the model
    currently stores under that name (see `_build_injection_mapper`).

    Args:
        model: The module whose state dict defines the target tensors.
        mapper: The user-provided mapper producing model state names.

    Returns:
        A sequential mapper that runs `mapper` first and the injection stage second.
    """
    output_names = set()
    for group in mapper.state_dependency_groups():
        for output_name in group.outputs:
            output_names.add(output_name)

    model_states = model.state_dict()
    injection_stage = ModelStateMapperParallel(
        [_build_injection_mapper(output_name, model_states[output_name]) for output_name in output_names]
    )
    return ModelStateMapperSequential([mapper, injection_stage])
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def load_model_state(
        src_dir: Path,
        mapper: ModelStateMapper,
        device: str,
        model: nn.Module,
        show_progress: bool = True,
):
    """
    Streams a checkpoint from disk straight into a PyTorch module.

    The loading lifecycle consists of three steps:

    1. Topology Mapping: `mapper` renames/stacks/reshapes on-disk states into model states.

    2. Automatic Distribution: for states the `model` holds as `DTensor`s, the loaded local tensors are
        sharded/replicated to match the model's placement schema.

    3. Streaming Read & Inject: each transformed state is injected into `model` with
        `load_state_dict(...)` as soon as it has been produced.

    NOTICE: Only states specified in `mapper` will be loaded! You can use
    `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that will load every
    model state without changing it.

    Args:
        src_dir: Directory containing .safetensors and index files.
        mapper: The topology defining how mapping from disk keys to model keys works.
        device: The device to load tensors onto (usually "cpu" or "cuda").
        model: The model instance to load weights into.
        show_progress: Whether to display the loading progress bar.
    """

    injection_mapper = _augment_mapper_for_injection(model, mapper)
    state_stream = read_model_state(
        src_dir=src_dir,
        mapper=injection_mapper,
        device=device,
        show_progress=show_progress
    )

    # states are injected one by one, hence the non-strict load
    for state_name, state_value in state_stream:
        model.load_state_dict({state_name: state_value}, strict=False)
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
from torch.distributed import DeviceMesh
|
|
7
|
+
from torch.distributed.tensor import DTensor
|
|
8
|
+
|
|
9
|
+
from d9d.model_state.mapper import ModelStateMapper
|
|
10
|
+
from d9d.model_state.mapper.compose import (
|
|
11
|
+
ModelStateMapperParallel,
|
|
12
|
+
ModelStateMapperSequential,
|
|
13
|
+
)
|
|
14
|
+
from d9d.model_state.mapper.leaf import (
|
|
15
|
+
ModelStateMapperGatherFullTensor,
|
|
16
|
+
ModelStateMapperIdentity,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from .writer import (
|
|
20
|
+
write_model_state_local,
|
|
21
|
+
write_model_state_pipeline_parallel,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _build_extraction_mapper(name: str, state: torch.Tensor) -> ModelStateMapper:
    """
    Selects the leaf mapper applied to a model state before it is handed to the user mapper.

    Args:
        name: Name of the state in the model's state dict.
        state: The tensor the model currently holds under `name`.

    Returns:
        A full-tensor gathering mapper when `state` is a `DTensor`; an identity mapper otherwise.
    """
    mapper_type = ModelStateMapperGatherFullTensor if isinstance(state, DTensor) else ModelStateMapperIdentity
    return mapper_type(name)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _augment_mapper_for_extraction(models: list[nn.Module], mapper: ModelStateMapper) -> ModelStateMapper:
    """
    Prepends an extraction stage to `mapper`, derived from the current states of `models`.

    Every input name declared by `mapper` gets a leaf mapper chosen from the tensor found under
    that name in the merged state dicts of `models` (see `_build_extraction_mapper`).

    Args:
        models: Modules whose state dicts are merged; on a name clash the later module wins.
        mapper: The user-provided mapper consuming model state names.

    Returns:
        A sequential mapper that runs the extraction stage first and `mapper` second.
    """
    input_names = set()
    for group in mapper.state_dependency_groups():
        for input_name in group.inputs:
            input_names.add(input_name)

    merged_states = {
        state_name: state
        for model in models
        for state_name, state in model.state_dict().items()
    }

    extraction_stage = ModelStateMapperParallel(
        [_build_extraction_mapper(input_name, merged_states[input_name]) for input_name in input_names]
    )
    return ModelStateMapperSequential([extraction_stage, mapper])
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _state_generator(models: list[nn.Module]) -> Iterable[tuple[str, torch.Tensor]]:
    """
    Lazily walks the state dicts of `models` in order.

    Args:
        models: Modules whose states should be emitted.

    Yields:
        `(state name, tensor)` pairs, module by module.
    """
    for model in models:
        for state_name, state in model.state_dict().items():
            yield state_name, state
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def save_model_state(
        dest_dir: Path,
        mapper: ModelStateMapper,
        model: nn.Module,
        shard_size_gb: float = 4.0,
        show_progress: bool = True
):
    """
    Saves a PyTorch model to disk from a **single** process.

    NOTICE: Only states specified in `mapper` will be saved! You can use
    `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that will save every
    model state without changing it.

    Args:
        dest_dir: The directory to save .safetensors shards and index.
        mapper: Topology defining how model keys map to disk keys.
        model: The PyTorch module to save.
        shard_size_gb: Max size per shard file in Gigabytes.
        show_progress: Whether to display a progress bar.
    """

    # the single-module case reuses the list-based helpers
    models = [model]
    extraction_mapper = _augment_mapper_for_extraction(models, mapper)
    states = _state_generator(models)

    write_model_state_local(
        dest_dir=dest_dir,
        mapper=extraction_mapper,
        state_generator=states,
        shard_size_gb=shard_size_gb,
        show_progress=show_progress
    )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def save_model_state_pipeline_parallel(
        dest_dir: Path,
        mapper: ModelStateMapper,
        device_mesh: DeviceMesh,
        pipeline_dim_name: str,
        models: list[nn.Module],
        shard_size_gb: float = 4.0,
        show_progress: bool = True
):
    """
    Saves a model to disk from within a Distributed Pipeline Parallel environment.

    Features:

    1. **Auto-Gather**: Converts `DTensor` parameters to full tensors before saving.

    2. **Distribution Awareness**: Uses the `device_mesh` to ensure that for a given pipeline stage,
        only the master rank writes the checkpoint, preventing Write-After-Write conflicts.

    3. **Index Merging**: Aggregates metadata from all independent pipeline stages into one global index file.

    NOTICE: Only states specified in `mapper` will be saved! You can use
    `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that will save every
    model state without changing it.

    Args:
        dest_dir: Directory to save .safetensors shards and index file.
        mapper: Topology defining how model keys map to disk keys.
        device_mesh: The cluster topology mesh.
        pipeline_dim_name: The specific dimension name in the mesh used for pipelining.
        models: A list of modules (pipeline stages) processed by this PP rank.
        shard_size_gb: Max size per shard file in Gigabytes.
        show_progress: Whether to display a progress bar.
    """
    extraction_mapper = _augment_mapper_for_extraction(models, mapper)
    states = _state_generator(models)

    write_model_state_pipeline_parallel(
        dest_dir=dest_dir,
        mapper=extraction_mapper,
        state_generator=states,
        device_mesh=device_mesh,
        pipeline_dim_name=pipeline_dim_name,
        shard_size_gb=shard_size_gb,
        show_progress=show_progress
    )
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from collections.abc import Generator, Iterable
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from safetensors import safe_open
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from d9d.model_state.io.dto import MODEL_STATE_INDEX_FILE_NAME, ModelStateIndex
|
|
10
|
+
from d9d.model_state.mapper import ModelStateMapper
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _StateLoadingFlow:
    """
    Internal orchestration logic for loading and transforming model states in a streamed manner.

    The flow reads the checkpoint index, plans which on-disk states every file has to provide,
    loads the files one at a time and runs each mapper dependency group as soon as all of its
    inputs are held in memory.
    """

    def __init__(
            self,
            src_dir: Path,
            mapper: ModelStateMapper,
            device: str,
            show_progress: bool
    ):
        """
        Reads the checkpoint index and validates the mapper against it.

        Args:
            src_dir: Directory containing the .safetensors files and the index file.
            mapper: Mapper whose dependency groups drive the loading.
            device: Device the tensors are loaded onto.
            show_progress: Whether to display a progress bar.

        Raises:
            ValueError: If the mapper requires states that are not listed in the index.
        """
        self._src_dir = src_dir
        self._mapper = mapper
        self._device = device

        # I/O in constructor: the index file is read from disk right here
        self._index = self._load_index()
        self._groups_to_process = set(mapper.state_dependency_groups())

        # on-disk states currently held in memory, waiting for their group to become complete
        self._stored_states: dict[str, torch.Tensor] = {}

        self._check_index()

        self._pbar = tqdm(
            desc="Loading Model States",
            total=sum(1 for group in self._groups_to_process for _ in group.outputs),
            disable=not show_progress
        )

    @staticmethod
    def _group_order_key(group) -> tuple[list[str], list[str]]:
        """Builds a sort key from the sorted input and output names of a dependency group."""
        return sorted(group.inputs), sorted(group.outputs)

    def _load_index(self) -> ModelStateIndex:
        """Reads and validates the index file located in the source directory."""
        index_file = self._src_dir / MODEL_STATE_INDEX_FILE_NAME
        index_data = index_file.read_text(encoding="utf-8")
        return ModelStateIndex.model_validate_json(index_data)

    def _check_index(self):
        """
        Verifies that every input required by the pending groups is listed in the index.

        Raises:
            ValueError: If at least one required input is absent from the index weight map.
        """
        will_process_inputs: set[str] = set()
        for group in self._groups_to_process:
            will_process_inputs.update(group.inputs)

        missing_inputs = will_process_inputs.difference(self._index.weight_map.keys())

        if missing_inputs:
            raise ValueError(f"Cannot run state loading: states {missing_inputs} are missing!")

    def _update_in_memory_states(self, file_to_load: str, params_to_load: set[str]):
        """Loads the requested tensors from a single .safetensors file into the in-memory store."""
        with safe_open(str(self._src_dir / file_to_load), framework="pt", device=str(self._device)) as st:
            for param_to_load in params_to_load:
                self._stored_states[param_to_load] = st.get_tensor(param_to_load)

    def _process_available_groups(self) -> Generator[tuple[str, torch.Tensor], None, None]:
        """
        Runs the mapper on every pending group whose inputs are all in memory.

        Ready groups are processed in a name-based order instead of set iteration order.
        NOTE(review): presumably mappers may run collectives (e.g. distributing DTensors), which
        would require every rank to process groups in the same order - confirm.

        Yields:
            `(name, tensor)` pairs produced by the mapper; the inputs of a group are evicted
            from memory once its outputs have been consumed.
        """
        stored_names = self._stored_states.keys()  # live view, reflects evictions below
        ready_groups = [group for group in self._groups_to_process if group.inputs.issubset(stored_names)]
        ready_groups.sort(key=self._group_order_key)

        for group in ready_groups:
            # an earlier group of this pass may have evicted an input shared with this one
            if not group.inputs.issubset(stored_names):
                continue

            self._groups_to_process.remove(group)

            loaded_states = self._mapper.apply(
                {name: self._stored_states[name] for name in sorted(group.inputs)}
            )
            yield from loaded_states.items()
            self._pbar.update(len(loaded_states))

            for input_name in group.inputs:
                del self._stored_states[input_name]

    def _build_file_loading_plan(self) -> dict[str, set[str]]:
        """
        Maps every checkpoint file to the set of states that have to be read from it.

        Groups and their inputs are visited in name order, so the file order of the returned
        plan does not depend on set iteration order.
        """
        plan = defaultdict(set)
        for group in sorted(self._mapper.state_dependency_groups(), key=self._group_order_key):
            for key in sorted(group.inputs):
                plan[self._index.weight_map[key]].add(key)
        return plan

    def load(self) -> Iterable[tuple[str, torch.Tensor]]:
        """
        Streams the transformed states file by file.

        Yields:
            `(name, tensor)` pairs produced by the mapper.

        Raises:
            ValueError: If some groups are still unprocessed after every planned file was loaded.
        """
        with self._pbar:
            for file_to_load, params_to_load in self._build_file_loading_plan().items():
                self._update_in_memory_states(file_to_load, params_to_load)
                yield from self._process_available_groups()

            # mirror the writing flow: never drop mapper groups silently
            if self._groups_to_process:
                unprocessed = sorted(sorted(group.inputs) for group in self._groups_to_process)
                raise ValueError(
                    f"Reading failed: not all mapper groups could be processed after loading all files. "
                    f"Inputs of unprocessed groups: {unprocessed}"
                )
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def read_model_state(
        src_dir: Path,
        mapper: ModelStateMapper,
        device: str,
        show_progress: bool = True
) -> Iterable[tuple[str, torch.Tensor]]:
    """
    Streams a model checkpoint from disk, transforming it on-the-fly with the state mapper.

    The mapper is analyzed to determine which files need to be loaded. Tensors are loaded into
    memory only when needed and evicted immediately after the mapper processes them.

    Args:
        src_dir: The directory containing .safetensors files and `model.safetensors.index.json` file.
        mapper: The transformation graph defining how to map on-disk keys to output keys.
        device: The device to load tensors onto (e.g., "cpu", "cuda:0").
        show_progress: Whether to display a progress bar.

    Yields:
        A tuple containing the transformed parameter name and its tensor value.
    """

    loading_flow = _StateLoadingFlow(
        src_dir=src_dir,
        mapper=mapper,
        device=device,
        show_progress=show_progress
    )
    yield from loading_flow.load()
|
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from collections.abc import Iterable
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import cast
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from safetensors.torch import save_file
|
|
8
|
+
from torch.distributed import DeviceMesh, ProcessGroup
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from d9d.core.dist_ops import all_gather_object
|
|
12
|
+
from d9d.model_state.io.dto import (
|
|
13
|
+
MODEL_STATE_INDEX_FILE_NAME,
|
|
14
|
+
ModelStateIndex,
|
|
15
|
+
ModelStateIndexMeta,
|
|
16
|
+
)
|
|
17
|
+
from d9d.model_state.mapper import ModelStateMapper
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class _StateWritingFlowLocal:
    """
    Internal orchestration logic for buffering, transforming, and sharding model states during save.

    Source states are fed in one by one. Every mapper dependency group runs as soon as all of
    its inputs are available, and on the master rank the outputs are buffered and flushed to
    temporary shard files whenever the shard size limit would be exceeded.
    """

    def __init__(
            self,
            dest_dir: Path,
            mapper: ModelStateMapper,
            shard_size_gb: float,
            show_progress: bool,
            sharding_rank: int,
            # The flow has to be invoked on every process (the mapper is applied on all of them),
            # but only the master rank buffers tensors and writes shard files.
            is_current_process_rank_master: bool
    ):
        """
        Prepares the writing flow.

        Args:
            dest_dir: Directory receiving the temporary shard files.
            mapper: Mapper whose dependency groups drive the saving.
            shard_size_gb: Max size per shard file in Gigabytes.
            show_progress: Whether to display a progress bar (shown on the master rank only).
            sharding_rank: Rank embedded into the temporary shard file names.
            is_current_process_rank_master: Whether this process buffers and writes the shards.
        """
        self._dest_dir = dest_dir
        self._mapper = mapper
        self._shard_size_bytes = int(shard_size_gb * (1024 ** 3))

        self._groups_to_process = set(mapper.state_dependency_groups())

        # source states received so far that still wait for their group to become complete
        self._available_source_states: dict[str, torch.Tensor] = {}

        self._total_size = 0
        self._pending_write_tensors: dict[str, torch.Tensor] = {}
        self._current_shard_size = 0

        self._sharding_rank = sharding_rank
        self._weight_name_to_local_shard_idx: dict[str, int] = {}
        self._local_shard_idx_to_tmp_path: dict[int, Path] = {}

        self._is_current_process_rank_master = is_current_process_rank_master
        total_num_outputs = sum(1 for group in self._groups_to_process for _ in group.outputs)
        self._pbar = tqdm(
            desc="Saving Model States",
            total=total_num_outputs,
            disable=not (show_progress and is_current_process_rank_master)
        )

    @staticmethod
    def _group_order_key(group) -> tuple[list[str], list[str]]:
        """Builds a sort key from the sorted input and output names of a dependency group."""
        return sorted(group.inputs), sorted(group.outputs)

    def _flush_shard(self):
        """Writes the buffered tensors to a new temporary shard file and resets the buffer."""
        if not self._pending_write_tensors:
            return

        local_shard_num = len(self._local_shard_idx_to_tmp_path) + 1
        shard_tmp_path = self._dest_dir / f".tmp-rank{self._sharding_rank}-shard-{local_shard_num}.safetensors"

        self._local_shard_idx_to_tmp_path[local_shard_num] = shard_tmp_path
        save_file(self._pending_write_tensors, str(shard_tmp_path))

        for state_name in self._pending_write_tensors:
            self._weight_name_to_local_shard_idx[state_name] = local_shard_num

        self._pbar.update(len(self._pending_write_tensors))

        self._total_size += self._current_shard_size

        self._pending_write_tensors.clear()
        self._current_shard_size = 0

    def _process_available_groups(self):
        """
        Runs the mapper on every pending group whose inputs are all available.

        Ready groups are processed in a name-based order instead of set iteration order.
        NOTE(review): presumably mappers may run collectives (e.g. gathering DTensors), which
        would require every rank to process groups in the same order - confirm.

        Raises:
            ValueError: On the master rank, if a single produced state exceeds the shard size.
        """
        available_names = self._available_source_states.keys()  # live view, reflects evictions below
        ready_groups = [group for group in self._groups_to_process if group.inputs.issubset(available_names)]
        ready_groups.sort(key=self._group_order_key)

        for group in ready_groups:
            # an earlier group of this pass may have evicted an input shared with this one
            if not group.inputs.issubset(available_names):
                continue

            self._groups_to_process.remove(group)

            states_to_save = self._mapper.apply(
                {name: self._available_source_states[name] for name in sorted(group.inputs)}
            )

            for input_name in group.inputs:
                del self._available_source_states[input_name]

            # proceed with stateful saving only on master rank
            if not self._is_current_process_rank_master:
                continue

            for name, tensor in states_to_save.items():
                update_size = tensor.numel() * tensor.element_size()

                if update_size > self._shard_size_bytes:
                    raise ValueError(f"Cannot save state {name} that is larger than shard size")

                # start a new shard when this tensor would overflow the current one
                if self._current_shard_size + update_size > self._shard_size_bytes:
                    self._flush_shard()

                self._pending_write_tensors[name] = tensor
                self._current_shard_size += update_size

    def _finalize_locally(self) -> ModelStateIndex:
        """
        Flushes the last shard and builds the index of everything written by this process.

        Returns:
            An index mapping every written weight to the name of its temporary shard file.

        Raises:
            ValueError: If some mapper groups never received all of their inputs.
        """
        self._flush_shard()

        if self._groups_to_process:
            missing_groups = {g.inputs for g in self._groups_to_process}
            raise ValueError(
                f"Writing failed: not all source tensors were provided to satisfy mapper dependencies. "
                f"Missing inputs for groups: {missing_groups}"
            )

        if self._available_source_states:
            warnings.warn(
                f"State Writing: The following source tensors were provided but not consumed by any "
                f"mapper group and will be ignored: {sorted(self._available_source_states.keys())}",
                stacklevel=2
            )

        weight_map_local = {
            name: self._local_shard_idx_to_tmp_path[shard_idx].name
            for name, shard_idx in self._weight_name_to_local_shard_idx.items()
        }

        return ModelStateIndex(
            metadata=ModelStateIndexMeta(total_size=self._total_size),
            weight_map=weight_map_local
        )

    def write(self, state_generator: Iterable[tuple[str, torch.Tensor]]) -> ModelStateIndex | None:
        """
        Consumes the source states and writes the resulting shards.

        Args:
            state_generator: Iterable of `(name, tensor)` source states.

        Returns:
            The local index on the master rank, `None` on every other rank.
        """
        with self._pbar:
            self._dest_dir.mkdir(parents=True, exist_ok=True)

            for name, tensor in state_generator:
                self._available_source_states[name] = tensor
                self._process_available_groups()

            if self._is_current_process_rank_master:
                return self._finalize_locally()
            return None
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _finalize_master(dest_dir: Path, indices: list[ModelStateIndex]):
    """
    Merges per-writer indices into the final checkpoint layout.

    Renames every shard file referenced by the given indices to
    ``model-XXXXX-of-YYYYY.safetensors`` and writes the combined index file
    into the destination directory.

    Args:
        dest_dir: Directory that contains the shard files and receives the index file.
        indices: Indices produced by the individual writers.
    """
    # Later indices override earlier ones when the same weight name appears twice.
    merged_local_map = {}
    for index in indices:
        merged_local_map.update(index.weight_map)

    # The "-of-YYYYY" part counts distinct file names across all indices.
    distinct_local_files = set()
    for index in indices:
        distinct_local_files.update(index.weight_map.values())
    shard_count = len(distinct_local_files)

    renamed_files = {}
    total_weight_map = {}
    for weight_name, old_file_name in merged_local_map.items():
        new_file_name = renamed_files.get(old_file_name)
        if new_file_name is None:
            # First time this shard is seen: assign the next 1-based ordinal and rename it.
            ordinal = len(renamed_files) + 1
            new_file_name = f"model-{ordinal:05d}-of-{shard_count:05d}.safetensors"
            (dest_dir / old_file_name).rename(dest_dir / new_file_name)
            renamed_files[old_file_name] = new_file_name
        total_weight_map[weight_name] = new_file_name

    total_size = sum(index.metadata.total_size for index in indices)
    final_index = ModelStateIndex(
        metadata=ModelStateIndexMeta(total_size=total_size),
        weight_map=total_weight_map
    )
    (dest_dir / MODEL_STATE_INDEX_FILE_NAME).write_text(
        final_index.model_dump_json(indent=4),
        encoding="utf-8"
    )
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def write_model_state_local(
    dest_dir: Path,
    mapper: ModelStateMapper,
    state_generator: Iterable[tuple[str, torch.Tensor]],
    shard_size_gb: float = 4.0,
    show_progress: bool = True
):
    """
    Saves model states to disk from a single local process.

    Runs one writer with sharding rank 0 that is flagged as master, then finalizes the
    checkpoint from the single index that writer returns.

    Args:
        dest_dir: Destination directory.
        mapper: Mapping to apply to states before saving.
        state_generator: Stream of (name, tensor) pairs to save.
        shard_size_gb: Maximum size of a single .safetensors file in GB.
        show_progress: Whether to show the progress bar.
    """
    writer = _StateWritingFlowLocal(
        dest_dir=dest_dir,
        mapper=mapper,
        shard_size_gb=shard_size_gb,
        show_progress=show_progress,
        sharding_rank=0,
        is_current_process_rank_master=True
    )
    written = writer.write(state_generator=state_generator)

    # write() returns None only for non-master writers; this one is always master.
    local_index = cast(ModelStateIndex, written)

    _finalize_master(dest_dir, [local_index])
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def write_model_state_distributed(
    dest_dir: Path,
    mapper: ModelStateMapper,
    state_generator: Iterable[tuple[str, torch.Tensor]],
    process_group: ProcessGroup,
    shard_size_gb: float = 4.0,
    show_progress: bool = True
):
    """
    Saves model states in a distributed setup (multiple processes).

    Every rank of the process group runs its own writer (each one flagged as master, so
    each returns an index). The indices are gathered across the group and rank 0
    finalizes the checkpoint.

    Args:
        dest_dir: Destination directory.
        mapper: Mapping to apply to states before saving.
        state_generator: Stream of (name, tensor) pairs from the model.
        process_group: The distributed process group.
        shard_size_gb: Maximum shard size in GB.
        show_progress: Whether to show the progress bar.
    """

    rank = process_group.rank()

    writer = _StateWritingFlowLocal(
        dest_dir=dest_dir,
        mapper=mapper,
        shard_size_gb=shard_size_gb,
        show_progress=show_progress,
        sharding_rank=rank,
        is_current_process_rank_master=True
    )
    local_index = writer.write(state_generator=state_generator)

    # Called on every rank; only rank 0 consumes the gathered list.
    gathered_indices = all_gather_object(local_index, process_group)
    if rank != 0:
        return

    present_indices = [index for index in gathered_indices if index is not None]
    _finalize_master(dest_dir, present_indices)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def write_model_state_pipeline_parallel(
    dest_dir: Path,
    mapper: ModelStateMapper,
    state_generator: Iterable[tuple[str, torch.Tensor]],
    device_mesh: DeviceMesh,
    pipeline_dim_name: str,
    shard_size_gb: float = 4.0,
    show_progress: bool = True
):
    """
    Saves model states in a complex ND distributed training setting.

    This handles Pipeline Parallelism by flagging only one rank per pipeline stage as
    master: the rank whose coordinates along every non-pipeline mesh dimension are zero.
    Only master ranks return an index from their writer. The indices are gathered along
    the pipeline dimension and a single rank finalizes the checkpoint.

    Args:
        dest_dir: Destination directory.
        mapper: Mapping to apply to states before saving.
        state_generator: Stream of (name, tensor) pairs from the model.
        device_mesh: The PyTorch DeviceMesh representing the cluster layout.
        pipeline_dim_name: The name of the mesh dimension responsible for pipeline parallelism.
        shard_size_gb: Maximum shard size in GB.
        show_progress: Whether to show the progress bar.

    Raises:
        ValueError: If the mesh has no dimension names or the current rank has no coordinates.
    """

    # NOTE(review): DeviceMesh.get_rank() appears to return the global rank rather than the
    # position along the pipeline dimension - confirm this is the intended sharding rank.
    pipeline_rank = device_mesh[pipeline_dim_name].get_rank()

    mesh_dim_names = device_mesh.mesh_dim_names
    coords = device_mesh.get_coordinate()
    if mesh_dim_names is None or coords is None:
        raise ValueError("Cannot save state using a DeviceMesh with no dim names or coords")

    # Master of a pipeline stage: origin of every mesh dimension except the pipeline one.
    master_within_pipeline_rank = all(
        coord == 0
        for name, coord
        in zip(mesh_dim_names, coords, strict=True)
        if name != pipeline_dim_name
    )

    current_idx = _StateWritingFlowLocal(
        dest_dir=dest_dir,
        mapper=mapper,
        shard_size_gb=shard_size_gb,
        show_progress=show_progress,
        sharding_rank=pipeline_rank,
        is_current_process_rank_master=master_within_pipeline_rank
    ).write(state_generator=state_generator)

    # Gather along the pipeline dimension itself rather than hard-coded mesh dimension 0:
    # the per-stage masters differ only in their pipeline coordinate, so this is the group
    # that holds all of their indices regardless of where the pipeline dimension sits.
    gather_idx = all_gather_object(current_idx, device_mesh.get_group(pipeline_dim_name))
    gather_idx_filter = [x for x in gather_idx if x is not None]
    if pipeline_rank == 0 and master_within_pipeline_rank:
        _finalize_master(dest_dir, gather_idx_filter)
|