d9d 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from typing import TypeVar, cast
|
|
2
|
+
|
|
3
|
+
import torch.distributed as dist
|
|
4
|
+
|
|
5
|
+
T = TypeVar("T")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def gather_object(
|
|
9
|
+
obj: T,
|
|
10
|
+
group: dist.ProcessGroup,
|
|
11
|
+
group_dst: int
|
|
12
|
+
) -> list[T] | None:
|
|
13
|
+
"""
|
|
14
|
+
Gathers picklable objects from the whole process group to a specific destination rank.
|
|
15
|
+
|
|
16
|
+
This acts as a wrapper around torch.distributed.gather_object that automatically
|
|
17
|
+
initializes the output buffer list on the destination rank.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
obj: The local object to send. Must be picklable.
|
|
21
|
+
group: The process group to work on.
|
|
22
|
+
group_dst: The rank within the group that will receive the objects.
|
|
23
|
+
|
|
24
|
+
Returns:
|
|
25
|
+
A list of objects from all ranks on the destination rank; None on other ranks.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
if group.rank() == group_dst:
|
|
29
|
+
# We initialize with None, but we cast to list[T] because we know
|
|
30
|
+
# dist.gather_object will populate these slots with actual objects.
|
|
31
|
+
save_list = cast(list[T], [None for _ in range(group.size())])
|
|
32
|
+
else:
|
|
33
|
+
save_list = None
|
|
34
|
+
dist.gather_object(
|
|
35
|
+
obj,
|
|
36
|
+
save_list,
|
|
37
|
+
group=group,
|
|
38
|
+
group_dst=group_dst
|
|
39
|
+
)
|
|
40
|
+
return save_list
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def all_gather_object(
    obj: T,
    group: dist.ProcessGroup
) -> list[T]:
    """
    Exchanges one picklable object per rank so that every rank of ``group`` gets all of them.

    Thin wrapper over ``torch.distributed.all_gather_object``. It allocates the output
    list, with one placeholder slot per group member, before issuing the collective.

    Args:
        obj: Object contributed by the calling rank. Must be picklable.
        group: Process group the collective runs over.

    Returns:
        A list of ``group.size()`` objects, one per group rank, on every rank.
    """
    # all_gather_object overwrites the placeholders in place. The cast only tells
    # the type checker what the list holds once the call has returned.
    gathered = cast(list[T], [None] * group.size())
    dist.all_gather_object(gathered, obj, group=group)
    return gathered
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import cast
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.distributed as dist
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def gather(
|
|
9
|
+
tensor: torch.Tensor,
|
|
10
|
+
group: dist.ProcessGroup,
|
|
11
|
+
group_dst: int,
|
|
12
|
+
async_op: bool = False
|
|
13
|
+
) -> list[torch.Tensor] | tuple[list[torch.Tensor] | None, dist.Work] | None:
|
|
14
|
+
"""
|
|
15
|
+
Gathers tensors from the process group to a specific destination rank.
|
|
16
|
+
|
|
17
|
+
This function assumes that tensors on all ranks have the same shape and dtype
|
|
18
|
+
as the tensor on the current rank. It automatically allocates the output
|
|
19
|
+
buffer list on the destination.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
tensor: The local tensor to send.
|
|
23
|
+
group: The process group to work on.
|
|
24
|
+
group_dst: The rank within the group that will receive the tensors.
|
|
25
|
+
async_op: Whether the operation should be asynchronous.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
If async_op is False: A list of tensors on the destination rank, None elsewhere.
|
|
29
|
+
If async_op is True: A tuple containing (buffer_list, work_handle).
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
if group.rank() == group_dst:
|
|
33
|
+
save_list = [torch.empty_like(tensor) for _ in range(group.size())]
|
|
34
|
+
else:
|
|
35
|
+
save_list = None
|
|
36
|
+
|
|
37
|
+
work = dist.gather(
|
|
38
|
+
tensor,
|
|
39
|
+
save_list,
|
|
40
|
+
group=group,
|
|
41
|
+
group_dst=group_dst,
|
|
42
|
+
async_op=async_op
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
if async_op:
|
|
46
|
+
return save_list, work
|
|
47
|
+
else:
|
|
48
|
+
return save_list
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def all_gather(
|
|
52
|
+
tensor: torch.Tensor,
|
|
53
|
+
group: dist.ProcessGroup,
|
|
54
|
+
async_op: bool = False
|
|
55
|
+
) -> list[torch.Tensor] | tuple[list[torch.Tensor], dist.Work]:
|
|
56
|
+
"""
|
|
57
|
+
Gathers tensors from the whole process group to all ranks.
|
|
58
|
+
|
|
59
|
+
This function assumes that tensors on all ranks have the same shape and dtype
|
|
60
|
+
as the tensor on the current rank. It automatically allocates the output
|
|
61
|
+
buffer list.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
tensor: The local tensor to send.
|
|
65
|
+
group: The process group to work on.
|
|
66
|
+
async_op: Whether the operation should be asynchronous.
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
If async_op is False: A list of gathered tensors.
|
|
70
|
+
If async_op is True: A tuple containing (buffer_list, work_handle).
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
save_list = [torch.empty_like(tensor) for _ in range(group.size())]
|
|
74
|
+
work = dist.all_gather(
|
|
75
|
+
save_list,
|
|
76
|
+
tensor,
|
|
77
|
+
group=group,
|
|
78
|
+
async_op=async_op
|
|
79
|
+
)
|
|
80
|
+
if async_op:
|
|
81
|
+
return save_list, work
|
|
82
|
+
else:
|
|
83
|
+
return save_list
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _all_gather_shapes(
|
|
87
|
+
tensor: torch.Tensor,
|
|
88
|
+
group: dist.ProcessGroup,
|
|
89
|
+
) -> Sequence[torch.Tensor]:
|
|
90
|
+
all_ndim = [torch.empty((), dtype=torch.long, device=tensor.device) for _ in range(group.size())]
|
|
91
|
+
all_ndim_wait = dist.all_gather(
|
|
92
|
+
all_ndim,
|
|
93
|
+
torch.tensor(tensor.ndim, dtype=torch.long, device=tensor.device),
|
|
94
|
+
group=group,
|
|
95
|
+
async_op=True
|
|
96
|
+
)
|
|
97
|
+
all_ndim_wait.wait()
|
|
98
|
+
|
|
99
|
+
all_shape = [torch.empty(cast(int, ndim.item()), dtype=torch.long, device=tensor.device) for ndim in all_ndim]
|
|
100
|
+
all_shape_wait = dist.all_gather(
|
|
101
|
+
all_shape,
|
|
102
|
+
torch.tensor(tensor.shape, dtype=torch.long, device=tensor.device),
|
|
103
|
+
group=group,
|
|
104
|
+
async_op=True
|
|
105
|
+
)
|
|
106
|
+
all_shape_wait.wait()
|
|
107
|
+
|
|
108
|
+
return all_shape
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def all_gather_variadic_shape(
    tensor: torch.Tensor,
    group: dist.ProcessGroup,
    async_op: bool = False
) -> list[torch.Tensor] | tuple[list[torch.Tensor], dist.Work]:
    """
    Exchanges tensors whose shapes may differ between ranks so every rank gets all of them.

    Shapes are exchanged first, synchronously, via ``_all_gather_shapes``. Each receive
    buffer is then sized exactly, and one data all-gather follows. The dtype and device
    of every buffer are taken from the local ``tensor``.

    Args:
        tensor: Tensor contributed by the calling rank.
        group: Process group the collective runs over.
        async_op: If True, only the final data all-gather is asynchronous. The shape
            exchange always blocks.

    Returns:
        Synchronous call: a list with one tensor per group rank, each with that rank's shape.
        Asynchronous call: a ``(buffers, work)`` pair.
    """
    peer_shapes = _all_gather_shapes(tensor, group)

    receive_buffers = [
        torch.empty(tuple(peer_shape), dtype=tensor.dtype, device=tensor.device)
        for peer_shape in peer_shapes
    ]
    handle = dist.all_gather(receive_buffers, tensor, group=group, async_op=async_op)
    if not async_op:
        return receive_buffers
    return receive_buffers, handle
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def gather_variadic_shape(
    tensor: torch.Tensor,
    group: dist.ProcessGroup,
    group_dst: int
) -> list[torch.Tensor] | None:
    """
    Gathers tensors of different shapes from the process group to a specific rank.

    Every rank of ``group`` must call this function, because shapes are first exchanged
    with an all-gather (``_all_gather_shapes``). The data then moves point-to-point:
    the destination posts one ``irecv`` per peer, and every other rank ``isend``s its
    tensor.

    The call blocks on every rank. The destination waits for all receives, and each
    sender waits on its send handle before returning. Any later use of ``tensor`` by
    the caller is therefore ordered after the send.

    Currently, does not support async_op.

    Args:
        tensor: The local tensor to send.
        group: The process group to work on.
        group_dst: The rank within the group that will receive the tensors.

    Returns:
        On the destination rank, a list with one tensor per group rank, ordered by group
        rank. The destination's own slot is ``tensor`` itself, not a copy.
        On other ranks, ``None``.
    """
    all_shape = _all_gather_shapes(tensor, group)

    if group.rank() != group_dst:
        send_work = dist.isend(tensor=tensor, group=group, group_dst=group_dst)
        # isend is asynchronous. Dropping the handle and returning would let the caller
        # modify or free ``tensor`` while it may still be in flight.
        if send_work is not None:
            send_work.wait()
        return None

    all_result: list[torch.Tensor] = []
    all_recv_works: list[dist.Work] = []
    for group_src_i, src_shape in enumerate(all_shape):
        if group_src_i == group_dst:
            all_result.append(tensor)
            continue
        # .tolist() converts the shape in one step. Iterating a device tensor
        # element-wise would sync once per dimension.
        recv_buffer = torch.empty(src_shape.tolist(), dtype=tensor.dtype, device=tensor.device)
        recv_work = dist.irecv(recv_buffer, group=group, group_src=group_src_i)
        all_recv_works.append(cast(dist.Work, recv_work))  # non-None: we are the dst rank
        all_result.append(recv_buffer)

    for recv_work in all_recv_works:
        recv_work.wait()
    return all_result
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Protocol, runtime_checkable
|
|
2
|
+
|
|
3
|
+
from torch.distributed.checkpoint.stateful import Stateful
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@runtime_checkable
|
|
7
|
+
class OptimizerProtocol(Protocol, Stateful):
|
|
8
|
+
"""
|
|
9
|
+
Protocol defining an interface for standard PyTorch Optimizer object.
|
|
10
|
+
|
|
11
|
+
This protocol ensures that the wrapped optimizer supports standard
|
|
12
|
+
API and state checkpointing via the Stateful interface.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
def step(self):
|
|
16
|
+
"""Performs a single optimization step."""
|
|
17
|
+
|
|
18
|
+
...
|
|
19
|
+
|
|
20
|
+
def zero_grad(self):
|
|
21
|
+
"""Sets the gradients of all optimized tensors to zero."""
|
|
22
|
+
|
|
23
|
+
...
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@runtime_checkable
|
|
27
|
+
class LRSchedulerProtocol(Protocol, Stateful):
|
|
28
|
+
"""
|
|
29
|
+
Protocol defining an interface for a Learning Rate Scheduler.
|
|
30
|
+
|
|
31
|
+
This protocol ensures that the wrapped scheduler supports stepping
|
|
32
|
+
and state checkpointing via the Stateful interface.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
def step(self):
|
|
36
|
+
"""Performs a single learning rate scheduling step."""
|
|
37
|
+
|
|
38
|
+
...
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .auto_spec import shard_spec_nothing, shard_spec_on_dim
|
|
2
|
+
from .shard import shard_tree
|
|
3
|
+
from .spec import ShardingSpec, ShardingSpecLeaf, SpecReplicate, SpecShard
|
|
4
|
+
from .unshard import unshard_tree
|
|
5
|
+
|
|
6
|
+
# Public API of the sharding package: spec types, spec builders, and shard/unshard helpers.
__all__ = [
    "ShardingSpec",
    "ShardingSpecLeaf",
    "SpecReplicate",
    "SpecShard",
    "shard_spec_nothing",
    "shard_spec_on_dim",
    "shard_tree",
    "unshard_tree",
]
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.utils._pytree as pytree # noqa: PLC2701
|
|
5
|
+
|
|
6
|
+
from d9d.core.types import PyTree
|
|
7
|
+
|
|
8
|
+
from .spec import ShardingSpec, ShardingSpecLeaf, SpecReplicate, SpecShard
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _tree_item_to_shard(item: Any, shard_on_dim: int) -> ShardingSpecLeaf:
    """
    Picks the sharding instruction for a single PyTree leaf.

    Args:
        item: The leaf value, which may be a list, a tensor, or anything else.
        shard_on_dim: Dimension on which shardable tensors should be split.

    Returns:
        ``SpecShard(0)`` for lists.
        ``SpecShard(shard_on_dim)`` for tensors with at least one dimension.
        ``SpecReplicate()`` for 0-dim tensors and for every other value.

    Raises:
        ValueError: If a list is asked to be split on a non-zero dimension, or a
            tensor has too few dimensions for ``shard_on_dim``.
    """
    if isinstance(item, list):
        if shard_on_dim == 0:
            return SpecShard(0)
        raise ValueError(f"Cannot shard list on dim {shard_on_dim}. Lists behave as 1D sequences.")

    # Non-tensors and scalar (0-dim) tensors have nothing to split.
    if not isinstance(item, torch.Tensor) or item.ndim == 0:
        return SpecReplicate()

    if shard_on_dim >= item.ndim:
        raise ValueError(f"Cannot shard {item.ndim}-dimensional tensor on dim {shard_on_dim}")
    return SpecShard(shard_on_dim)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def shard_spec_on_dim(tree: PyTree[Any], dim: int) -> ShardingSpec:
    """
    Builds a sharding spec that splits every eligible leaf of ``tree`` along ``dim``.

    Lists are treated as whole leaves rather than as containers. Each leaf is mapped
    as follows:

    * A Tensor with more than ``dim`` dimensions becomes ``SpecShard(dim)``.
    * A list becomes ``SpecShard(0)``. Lists can only be split on dim 0.
    * 0-dim tensors and all other leaf types become ``SpecReplicate()``.

    Args:
        tree: The input PyTree structure.
        dim: The dimension index to shard eligible tensors on.

    Returns:
        A new PyTree with the same structure as ``tree``, holding ``SpecShard`` or
        ``SpecReplicate`` objects at the leaves.

    Raises:
        ValueError: If a non-scalar tensor has ``ndim <= dim``, or if the tree contains
            a list and ``dim`` is not 0.
    """

    def _is_spec_leaf(node: Any) -> bool:
        # Stop descending at tensors and lists; both are sharded as single units.
        return isinstance(node, (torch.Tensor, list))

    def _leaf_spec(node: Any) -> ShardingSpecLeaf:
        return _tree_item_to_shard(node, dim)

    return pytree.tree_map(_leaf_spec, tree, is_leaf=_is_spec_leaf)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def shard_spec_nothing(tree: PyTree[Any]) -> ShardingSpec:
    """
    Builds a sharding spec that replicates every leaf of ``tree``.

    The result mirrors the structure of ``tree``. Tensors and lists are treated as
    single leaves, and every leaf gets its own ``SpecReplicate()`` instance.

    Args:
        tree: The input PyTree structure.

    Returns:
        A new PyTree with the same structure as ``tree``, holding only ``SpecReplicate``
        objects at the leaves.
    """

    def _is_spec_leaf(node: Any) -> bool:
        return isinstance(node, (torch.Tensor, list))

    def _replicate(_node: Any) -> ShardingSpecLeaf:
        return SpecReplicate()

    return pytree.tree_map(_replicate, tree, is_leaf=_is_spec_leaf)
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import TypeVar, cast
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.utils._pytree as pytree # noqa: PLC2701
|
|
6
|
+
|
|
7
|
+
from d9d.core.types import PyTree
|
|
8
|
+
|
|
9
|
+
from .spec import ShardingSpec, SpecReplicate, SpecShard
|
|
10
|
+
|
|
11
|
+
TLeaf = TypeVar("TLeaf")
TSameTree = TypeVar("TSameTree", bound=PyTree)


def _shard_list(
    item: list[TLeaf],
    spec: SpecShard,
    num_shards: int,
    enforce_even_split: bool
) -> Sequence[list[TLeaf] | TLeaf]:
    """
    Splits a list into ``num_shards`` pieces along its only (0th) dimension.

    With ``spec.do_stack`` the list must hold exactly one element per shard, and it
    is returned as-is. Otherwise the list is cut into contiguous slices. When the
    length is not divisible, the first ``len(item) % num_shards`` slices hold one
    extra element each.

    Raises:
        ValueError: If ``spec.dim`` is not 0, if the ``do_stack`` length check fails,
            or if ``enforce_even_split`` is set and the length is not divisible.
    """
    if spec.dim != 0:
        raise ValueError(f"Lists can only be sharded on dim 0, got {spec.dim}")

    if spec.do_stack:
        if len(item) != num_shards:
            raise ValueError(
                f"do_stack=True requires list length ({len(item)}) to match num_shards ({num_shards})"
            )
        return item

    if enforce_even_split and len(item) % num_shards != 0:
        raise ValueError(
            f"Tried to shard a list with length {len(item)} "
            f"to {num_shards} shards, but the length is not perfectly divisible."
        )

    base_size, remainder = divmod(len(item), num_shards)
    pieces: list[list[TLeaf]] = []
    start = 0
    for shard_id in range(num_shards):
        stop = start + base_size + (1 if shard_id < remainder else 0)
        pieces.append(item[start:stop])
        start = stop
    return pieces
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _shard_tensor(
    item: torch.Tensor,
    spec: SpecShard,
    num_shards: int,
    enforce_even_split: bool
) -> Sequence[torch.Tensor]:
    """
    Splits a tensor into ``num_shards`` pieces along ``spec.dim``.

    Without ``do_stack``, this uses ``torch.tensor_split``, so uneven sizes are allowed
    unless ``enforce_even_split`` is set. With ``do_stack``, the dimension must have
    length exactly ``num_shards``. It is then removed with ``torch.unbind``, one slice
    per shard.

    Raises:
        ValueError: For 0-dim tensors, a failed ``do_stack`` length check, or an uneven
            split when ``enforce_even_split`` is set.
    """
    if item.ndim == 0:
        raise ValueError("Found a 0-dim Tensor for sharding")

    if not spec.do_stack:
        if enforce_even_split and item.shape[spec.dim] % num_shards != 0:
            raise ValueError(
                f"Tried to shard a tensor with shape {item.shape} on dim {spec.dim} "
                f"to {num_shards} shards, but the dimension is not perfectly divisible."
            )
        return torch.tensor_split(item, sections=num_shards, dim=spec.dim)

    if item.shape[spec.dim] != num_shards:
        raise ValueError(
            f"do_stack=True requires tensor shape[{spec.dim}] ({item.shape[spec.dim]}) "
            f"to match num_shards ({num_shards})"
        )
    return torch.unbind(item, dim=spec.dim)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _shard_leaf_to_list(
    item: TLeaf,
    spec: SpecShard | SpecReplicate,
    num_shards: int,
    enforce_even_split: bool
) -> Sequence[TLeaf]:
    """
    Turns one leaf into a sequence with one entry per shard.

    ``SpecReplicate`` gives every shard a reference to the same object; nothing is
    copied. ``SpecShard`` dispatches on the leaf type to ``_shard_tensor`` or
    ``_shard_list``.

    Raises:
        TypeError: If ``spec`` is neither spec type, or a ``SpecShard`` targets a leaf
            that is neither a Tensor nor a list.
    """
    if isinstance(spec, SpecReplicate):
        return [item] * num_shards
    if not isinstance(spec, SpecShard):
        raise TypeError(f"Unknown sharding spec object type: {type(spec)}")

    if isinstance(item, torch.Tensor):
        pieces = _shard_tensor(
            item=item,
            num_shards=num_shards,
            enforce_even_split=enforce_even_split,
            spec=spec,
        )
    elif isinstance(item, list):
        pieces = _shard_list(
            item=item,
            num_shards=num_shards,
            enforce_even_split=enforce_even_split,
            spec=spec,
        )
    else:
        raise TypeError(
            f"Sharding spec found a SpecShard object, but the item was not a Tensor and not a list (got {type(item)})"
        )
    return cast(Sequence[TLeaf], pieces)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def shard_tree(
    tree: TSameTree,
    sharding_spec: ShardingSpec,
    num_shards: int,
    enforce_even_split: bool
) -> tuple[TSameTree, ...]:
    """
    Shards a PyTree into a tuple of PyTrees, one for each shard rank.

    This function takes a single global data structure and splits it into `num_shards`
    structures.

    * If a spec leaf is a ``SpecShard(dim)``, the tensor or list is split along that dimension,
      and the ``i``-th slice goes to the ``i``-th output tree.
    * If a spec leaf is ``SpecReplicate``, the item is replicated (reference copy) to all
      output trees.

    A tree without any leaves (for example an empty dict) yields ``num_shards`` copies of
    the bare structure.

    Args:
        tree: The structure containing tensors to be sharded.
        sharding_spec: A structure matching 'tree' containing ``SpecShard`` or ``SpecReplicate`` objects.
        num_shards: The total number of shards to split the tensors into.
        enforce_even_split: If True, raises a ValueError if a tensor's dimension
            size is not perfectly divisible by ``num_shards``.

    Returns:
        A tuple of length ``num_shards``. Each element is a PyTree matching
        the structure of the input ``tree``, containing the local data for
        that specific rank.

    Raises:
        ValueError: If tree structures do not match, or valid sharding conditions
            are not met.
    """
    flat_spec, spec_struct = pytree.tree_flatten(sharding_spec)

    try:
        flat_tree = spec_struct.flatten_up_to(tree)
    except (ValueError, TypeError) as e:
        raise ValueError("Tree structure does not match sharding spec") from e

    sharded_leaves_per_node = [
        _shard_leaf_to_list(item, spec, num_shards, enforce_even_split)
        for item, spec in zip(flat_tree, flat_spec, strict=True)
    ]

    if not sharded_leaves_per_node:
        # zip(*[]) yields nothing. That would return an empty tuple instead of the
        # documented ``num_shards`` trees, so rebuild the leafless structure per shard.
        return tuple(spec_struct.unflatten([]) for _ in range(num_shards))

    # Transpose "per leaf -> shards" into "per shard -> leaves". strict=True guards
    # against any leaf having produced a different number of shards.
    rank_leaves = list(zip(*sharded_leaves_per_node, strict=True))

    return tuple(spec_struct.unflatten(leaves) for leaves in rank_leaves)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
|
|
3
|
+
from d9d.core.types import PyTree
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclasses.dataclass(slots=True, frozen=True)
|
|
7
|
+
class SpecReplicate:
|
|
8
|
+
"""
|
|
9
|
+
Specifies that a leaf node should be replicated across all shards.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclasses.dataclass(slots=True, frozen=True)
|
|
14
|
+
class SpecShard:
|
|
15
|
+
"""
|
|
16
|
+
Specifies that a leaf node should be split along a specific dimension.
|
|
17
|
+
|
|
18
|
+
Attributes:
|
|
19
|
+
dim: The dimension to split.
|
|
20
|
+
do_stack: If True, sharding will squeeze the sharded dimension (it should be exactly the num_shards length)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
dim: int
|
|
24
|
+
do_stack: bool = False
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# A single leaf of a sharding specification: either replicate the data leaf
# to every shard, or split it along a dimension.
ShardingSpecLeaf = SpecReplicate | SpecShard
# A PyTree of ``ShardingSpecLeaf`` objects whose structure matches the data
# tree it describes.
ShardingSpec = PyTree[ShardingSpecLeaf]
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import TypeVar, cast
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.utils._pytree as pytree # noqa: PLC2701
|
|
6
|
+
|
|
7
|
+
from d9d.core.types import PyTree
|
|
8
|
+
|
|
9
|
+
from .spec import ShardingSpec, ShardingSpecLeaf, SpecReplicate, SpecShard
|
|
10
|
+
|
|
11
|
+
# Type of an individual leaf value gathered from the per-rank trees.
TLeaf = TypeVar("TLeaf")
# Tree type shared by every per-rank input and by the merged output of
# ``unshard_tree``.
TSameTree = TypeVar("TSameTree", bound=PyTree)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _unshard_list(
        group: Sequence[list[TLeaf] | TLeaf],
        spec: SpecShard
) -> list[TLeaf]:
    """
    Merges per-rank list shards into a single list.

    Args:
        group: One entry per rank, in rank order. With ``spec.do_stack`` each
            entry becomes one element of the result; otherwise each entry is
            an iterable whose elements are concatenated.
        spec: The shard spec; lists can only be merged on ``dim == 0``.

    Returns:
        The combined list.

    Raises:
        ValueError: If ``spec.dim`` is not 0.
    """
    if spec.dim != 0:
        raise ValueError(f"Lists can only be unsharded on dim 0, got {spec.dim}")

    if not spec.do_stack:
        # Concatenate the rank-local lists, preserving rank order.
        return [element for part in group for element in cast(list[TLeaf], part)]

    # Stacked mode: every rank's entry becomes a single element of the result.
    return cast(list[TLeaf], list(group))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _unshard_tensor(
        group: list[torch.Tensor],
        spec: SpecShard
) -> torch.Tensor:
    """
    Merges per-rank tensor shards along ``spec.dim``.

    With ``spec.do_stack`` the shards are stacked into a new dimension at
    ``spec.dim``; otherwise they are concatenated along the existing one.
    """
    combine = torch.stack if spec.do_stack else torch.cat
    return combine(group, dim=spec.dim)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _unshard_leaf_from_group(
        group: Sequence[TLeaf],
        spec: ShardingSpecLeaf
) -> TLeaf:
    """
    Merges one leaf's per-rank values into a single value according to ``spec``.

    Args:
        group: The leaf's value from each rank, in rank order.
        spec: The sharding spec leaf for this position in the tree.

    Returns:
        For ``SpecReplicate``, the first rank's value. For ``SpecShard``, the
        tensors concatenated or stacked along ``spec.dim``, or the lists merged.

    Raises:
        TypeError: If ``spec`` is of an unknown type, or if a sharded leaf is
            neither a tensor nor a list while ``spec.do_stack`` is False.
        ValueError: Propagated from list merging when ``spec.dim != 0``.
    """
    if isinstance(spec, SpecReplicate):
        # Replicated data is assumed identical on every rank; rank 0's copy wins.
        return group[0]
    if not isinstance(spec, SpecShard):
        raise TypeError(f"Unknown sharding spec object type: {type(spec)}")

    head = group[0]
    if isinstance(head, torch.Tensor):
        merged_tensor = _unshard_tensor(cast(list[torch.Tensor], group), spec)
        return cast(TLeaf, merged_tensor)
    if isinstance(head, list) or spec.do_stack:
        return cast(TLeaf, _unshard_list(group, spec))
    raise TypeError(f"Expected Tensor or list instances, got {type(head)}")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def unshard_tree(
        sharded_trees: Sequence[TSameTree],
        sharding_spec: ShardingSpec
) -> TSameTree:
    """
    Combines a sequence of PyTrees (one per rank) into a single global PyTree.

    This is the inverse of ``shard_tree``. Each per-rank tree is flattened
    against the structure of ``sharding_spec``, and corresponding leaves from
    all ranks are merged.

    * If the spec for a leaf is ``SpecShard(dim)``, the tensors (or lists) from
      all ranks are concatenated along that dimension, or stacked if
      ``do_stack=True``. List leaves only support ``dim=0``.
    * If the spec is ``SpecReplicate``, it assumes the data is replicated
      and takes the item from the first rank.

    Args:
        sharded_trees: A sequence (list or tuple) of PyTrees. There must be
            one tree for each shard rank, and they must all share the same
            structure as ``sharding_spec``.
        sharding_spec: A structure matching the input trees containing
            ``SpecShard`` or ``SpecReplicate`` objects.

    Returns:
        A single PyTree with the structure of ``sharding_spec``, in which
        sharded leaves have been merged back into full tensors or lists.

    Raises:
        ValueError: If ``sharded_trees`` is empty, if any tree's structure
            does not match the spec, or if a list leaf is sharded on a
            dimension other than 0.
        TypeError: If a spec leaf is neither ``SpecShard`` nor
            ``SpecReplicate``, or if a sharded leaf is neither a tensor nor a
            list while ``do_stack`` is False.
    """
    if not sharded_trees:
        raise ValueError("sharded_trees sequence cannot be empty")

    flat_spec, spec_struct = pytree.tree_flatten(sharding_spec)

    # Flatten each rank's tree only down to the spec's leaves, so that a spec
    # leaf may correspond to a whole subtree (e.g. a list) in the data.
    flat_shards_per_rank = []
    for i, tree in enumerate(sharded_trees):
        try:
            leaves = spec_struct.flatten_up_to(tree)
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"Structure mismatch at shard {i}: tree does not match sharding spec structure"
            ) from e

        flat_shards_per_rank.append(leaves)

    # Transpose rank-major leaves into leaf-major groups: one tuple per spec
    # leaf, holding that leaf's value from every rank.
    grouped_leaves = list(zip(*flat_shards_per_rank, strict=True))

    reconstructed_leaves = [
        _unshard_leaf_from_group(group, spec)
        for group, spec in zip(grouped_leaves, flat_spec, strict=True)
    ]

    return spec_struct.unflatten(reconstructed_leaves)
|