d9d 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import dataclasses
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclasses.dataclass(frozen=True)
class StateGroup:
    """
    An atomic unit of dependency in the model state transformation graph.

    Each group is a strict contract between a set of source keys and the set
    of destination keys derived from them. Immutability (a frozen dataclass
    over frozensets) makes instances hashable, so groups can be collected into
    sets and used as dictionary keys by the orchestration layer.

    Attributes:
        inputs: Every key that must be present in the source state dictionary
            for this dependency to be satisfiable.
        outputs: Every key this transformation will produce.
    """

    inputs: frozenset[str]
    outputs: frozenset[str]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ModelStateMapper(abc.ABC):
    """
    The abstract base class for all model state transformation operations.

    This class serves as the interface between the definition of a transformation
    topology and the actual execution of tensor operations.

    It enforces a Declarative vs. Imperative separation of concerns:

    1. Declarative (Topology): Through `state_dependency_groups()`, the mapper
       announces *what* it intends to do without handling any data. This allows
       the system to build execution graphs, validate chains, detect collisions,
       and shard tasks *before* allocating memory.
    2. Imperative (Execution): Through `apply()`, the mapper performs the
       actual logic (PyTorch operations) on model states.
    """

    @abc.abstractmethod
    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """
        Calculates and returns the set of independent dependency groups this mapper handles.

        Returns:
            A frozenset of `StateGroup` objects. Each group represents a
            disjoint operation. For example, a mapper that renames ten
            independent tensors would return ten distinct `StateGroup` objects,
            allowing them to be sharded or processed individually.
        """
        ...

    @abc.abstractmethod
    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Executes the transformation logic on a specific dictionary of tensors.

        The orchestration system guarantees that the `group` dictionary passed
        here contains all keys listed in the `inputs` of the active `StateGroup`.

        Implementations must guarantee that the result contains all keys listed
        in that group's `outputs`.

        Args:
            group: A dictionary containing the source data.
                Keys match `StateGroup.inputs`.

        Returns:
            A dictionary containing the transformed data. Keys must strictly
            match `StateGroup.outputs`.
        """
        ...
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This package provides utility functions that are used to create simple ModelStateMapper instances from objects
|
|
3
|
+
such as PyTorch modules or other StateMappers
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .mapper import identity_mapper_from_mapper_outputs
|
|
7
|
+
from .module import identity_mapper_from_module
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"identity_mapper_from_mapper_outputs",
|
|
11
|
+
"identity_mapper_from_module"
|
|
12
|
+
]
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from d9d.model_state.mapper import ModelStateMapper
|
|
2
|
+
from d9d.model_state.mapper.compose import ModelStateMapperParallel
|
|
3
|
+
from d9d.model_state.mapper.leaf import ModelStateMapperIdentity
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def identity_mapper_from_mapper_outputs(mapper: ModelStateMapper) -> ModelStateMapper:
    """
    Builds an identity mapper covering every output produced by `mapper`.

    Each key found in the `outputs` of any of the source mapper's
    `state_dependency_groups()` becomes its own `ModelStateMapperIdentity`,
    and all of them are combined into a single parallel mapper.

    Args:
        mapper: The mapper whose output signature is inspected.

    Returns:
        A composite mapper that passes every produced key through unchanged.
    """

    identity_mappers: list[ModelStateMapper] = [
        ModelStateMapperIdentity(output_name)
        for state_group in mapper.state_dependency_groups()
        for output_name in state_group.outputs
    ]

    return ModelStateMapperParallel(identity_mappers)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from torch import nn
|
|
2
|
+
|
|
3
|
+
from d9d.model_state.mapper import ModelStateMapper
|
|
4
|
+
from d9d.model_state.mapper.compose import ModelStateMapperParallel
|
|
5
|
+
from d9d.model_state.mapper.leaf import ModelStateMapperIdentity
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def identity_mapper_from_module(module: nn.Module) -> ModelStateMapper:
    """
    Builds a pass-through mapper for every entry in a module's state dict.

    Useful when the source checkpoint keys are expected to exactly match the
    model's current state-dict names (standard `load_state_dict` behavior).
    Note that `state_dict()` covers parameters and buffers alike.

    Args:
        module: The instantiated PyTorch module to inspect.

    Returns:
        A parallel mapper with one identity entry per state-dict key.
    """

    mappers: list[ModelStateMapper] = []
    for state_key in module.state_dict():
        mappers.append(ModelStateMapperIdentity(state_key))

    return ModelStateMapperParallel(mappers)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Complex state mappers are built using composition. This package provides ModelStateMapper implementations that
|
|
3
|
+
are composed of other mappers.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
from .helper import filter_empty_mappers
|
|
8
|
+
from .parallel import ModelStateMapperParallel
|
|
9
|
+
from .sequential import ModelStateMapperSequential
|
|
10
|
+
from .shard import ModelStateMapperShard
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"ModelStateMapperParallel",
|
|
14
|
+
"ModelStateMapperSequential",
|
|
15
|
+
"ModelStateMapperShard",
|
|
16
|
+
"filter_empty_mappers"
|
|
17
|
+
]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
from d9d.model_state.mapper.abc import ModelStateMapper
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def filter_empty_mappers(mappers: Sequence[ModelStateMapper]) -> list[ModelStateMapper]:
    """
    Drops mappers that declare no work at all.

    A mapper is kept as soon as any of its dependency groups mentions at least
    one input or output key; mappers whose groups are all empty (or that have
    no groups at all) are removed.

    Args:
        mappers: The list of mappers to filter.

    Returns:
        A new list containing only active mappers, in their original order.
    """
    return [
        mapper
        for mapper in mappers
        if any(group.inputs or group.outputs for group in mapper.state_dependency_groups())
    ]
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
|
|
6
|
+
from d9d.model_state.mapper.compose.helper import filter_empty_mappers
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ModelStateMapperParallel(ModelStateMapper):
    """
    Combines several independent state mappers into one logical mapper.

    Every wrapped mapper must be isolated from its siblings: a source key may
    be consumed by at most one mapper (no input collisions) and a destination
    key may be produced by at most one mapper (no output collisions).
    Violations are rejected at construction time.

    When `apply` is called, the input dictionary's key set is used to look up
    the single sub-mapper responsible for exactly those keys.
    """

    def __init__(self, mappers: Sequence[ModelStateMapper]):
        """
        Validates isolation and builds the key-set routing table.

        Args:
            mappers: Sub-mappers to run alongside each other; mappers with no
                inputs and no outputs are dropped.

        Raises:
            ValueError: If two sub-mappers share an input key or an output key.
        """
        active = filter_empty_mappers(mappers)

        collected_groups: set = set()
        routing: dict = {}

        claimed_inputs: set[str] = set()
        claimed_outputs: set[str] = set()
        for sub_mapper in active:
            for dependency in sub_mapper.state_dependency_groups():
                if not claimed_inputs.isdisjoint(dependency.inputs):
                    raise ValueError(f"Found a colliding input group: {dependency.inputs}")
                claimed_inputs.update(dependency.inputs)

                if not claimed_outputs.isdisjoint(dependency.outputs):
                    raise ValueError(f"Found colliding output keys: {dependency.outputs}")
                claimed_outputs.update(dependency.outputs)

                collected_groups.add(dependency)
                routing[dependency.inputs] = sub_mapper

        self._all_groups = frozenset(collected_groups)
        self._inputs_to_mapper = routing

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """Returns the union of every sub-mapper's dependency groups."""
        return self._all_groups

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Routes the given tensors to the sub-mapper that owns their key set.

        Args:
            group: Source tensors; the key set must exactly match the
                `inputs` of one registered dependency group.

        Returns:
            The routed sub-mapper's output dictionary.

        Raises:
            ValueError: If no sub-mapper is registered for this key set.
        """
        requested_keys = frozenset(group.keys())

        if requested_keys not in self._inputs_to_mapper:
            raise ValueError("Tried to run a parallel mapper with undefined group. Perhaps you sent groups that are "
                             "not isolated?")

        return self._inputs_to_mapper[requested_keys].apply(group)
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from collections.abc import Set as AbstractSet
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
|
|
7
|
+
from d9d.model_state.mapper.compose.helper import filter_empty_mappers
|
|
8
|
+
from d9d.model_state.mapper.compose.parallel import ModelStateMapperParallel
|
|
9
|
+
from d9d.model_state.mapper.leaf.identity import ModelStateMapperIdentity
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ModelStateMapperSequential(ModelStateMapper):
    """
    Executes a list of mappers in a specific sequence (pipeline).

    This class manages the data flow from one mapper to the next. It abstracts
    away intermediate states, exposing only the inputs required by the first
    relevant stage and the outputs produced by the final relevant stage.

    Key Features:

    1. **Gap Filling**: Automatically injects `Identity` mappers if a tensor needs
       to pass through a stage without modification to reach a later stage or
       the final output.

    2. **Group Merging**: Computes the net dependency graph. If Stage A requires 'x'
       and produces 'y', and Stage B requires 'y' and produces 'z', the
       Sequential mapper reports a single group `{x} -> {z}`.
    """

    def __init__(self, mappers: list[ModelStateMapper]):
        """
        Builds the pipeline.

        Args:
            mappers: The stages to execute, in order.

        Raises:
            ValueError: If no mapper remains after filtering out empty ones.
        """
        mappers = filter_empty_mappers(mappers)
        if not mappers:
            raise ValueError("Mappers list cannot be empty.")

        mappers = self._fill_gaps(mappers)

        self._groups = self._compute_pipeline_groups(mappers)
        self._mappers = mappers

    @staticmethod
    def _fill_gaps(mappers: list[ModelStateMapper]) -> list[ModelStateMapper]:
        """Wraps stages with identity mappers so tensors can pass through unchanged."""
        mappers = mappers.copy()

        # Walk bottom-up: if a later stage needs a key the previous stage does
        # not produce, attach an identity mapper to the previous stage so the
        # key survives that stage.
        for stage_i in reversed(range(1, len(mappers))):
            current_requires = frozenset().union(
                *(g.inputs for g in mappers[stage_i].state_dependency_groups())
            )
            prev_produces = frozenset().union(
                *(g.outputs for g in mappers[stage_i - 1].state_dependency_groups())
            )

            needs_to_pass_through = current_requires - prev_produces

            mappers[stage_i - 1] = ModelStateMapperParallel(
                [mappers[stage_i - 1]] + [ModelStateMapperIdentity(key) for key in needs_to_pass_through]
            )

        # Walk top-down: if an earlier stage produces a key the next stage does
        # not consume, attach an identity mapper to the next stage so the key
        # is forwarded to the end of the pipeline.
        for stage_i in range(0, len(mappers) - 1):
            current_produces = frozenset().union(
                *(g.outputs for g in mappers[stage_i].state_dependency_groups())
            )
            next_requires = frozenset().union(
                *(g.inputs for g in mappers[stage_i + 1].state_dependency_groups())
            )

            needs_to_pass_through = current_produces - next_requires

            mappers[stage_i + 1] = ModelStateMapperParallel(
                [mappers[stage_i + 1]] + [ModelStateMapperIdentity(key) for key in needs_to_pass_through]
            )

        return mappers

    @staticmethod
    def _compute_pipeline_groups(mappers: list[ModelStateMapper]) -> frozenset[StateGroup]:
        """Traces each final output group back to the source inputs it depends on."""
        outputs_depend_on_inputs: dict[frozenset[str], frozenset[str]] = {}

        # After gap filling the stage graph is fully connected, so walking
        # backwards from each final group visits every contributing stage.
        for last_group_traced in mappers[-1].state_dependency_groups():
            required_inputs = last_group_traced.inputs

            for mapper_i in reversed(range(0, len(mappers) - 1)):
                contributing = [
                    g for g in mappers[mapper_i].state_dependency_groups()
                    if not g.outputs.isdisjoint(required_inputs)
                ]
                # frozenset().union(...) tolerates an empty contribution list,
                # unlike the unbound frozenset.union(*...) form.
                required_inputs = frozenset().union(*(g.inputs for g in contributing))

            outputs_depend_on_inputs[last_group_traced.outputs] = required_inputs

        return ModelStateMapperSequential._merge_groups(list(outputs_depend_on_inputs.items()))

    @staticmethod
    def _merge_groups(groups: Sequence[tuple[AbstractSet[str], AbstractSet[str]]]) -> frozenset[StateGroup]:
        """
        Merges overlapping traced dependencies into disjoint state groups.

        Args:
            groups: Pairs of (output keys, input keys), as produced by
                `_compute_pipeline_groups`.

        Returns:
            Disjoint `StateGroup` objects covering the same dependencies.

        Note:
            The previous implementation re-fed merged (inputs, outputs) tuples
            into a loop that unpacked them as (outputs, inputs), flipping the
            orientation on every merge pass; pipelines whose merging settled
            after an even number of passes reported inputs and outputs
            swapped. Items are now normalized once and kept in a fixed
            (inputs, outputs) orientation throughout.
        """
        # Normalize the incoming (outputs, inputs) pairs to mutable
        # (inputs, outputs) working sets.
        pending: list[tuple[set[str], set[str]]] = [
            (set(input_names), set(output_names)) for output_names, input_names in groups
        ]

        modified = True
        while modified:
            modified = False
            merged: list[tuple[set[str], set[str]]] = []
            for input_names, output_names in pending:
                absorbed = False
                for existing_inputs, existing_outputs in merged:
                    if existing_inputs & input_names or existing_outputs & output_names:
                        existing_inputs.update(input_names)
                        existing_outputs.update(output_names)
                        absorbed = True
                        modified = True

                # An item overlapping several existing groups updates all of
                # them; the outer while-loop re-runs until those groups also
                # collapse into one.
                if not absorbed:
                    merged.append((set(input_names), set(output_names)))

            pending = merged

        return frozenset(
            StateGroup(inputs=frozenset(i), outputs=frozenset(o)) for i, o in pending
        )

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """Returns the merged end-to-end dependency groups of the pipeline."""
        return self._groups

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Runs the tensors through every stage in order.

        Args:
            group: Source tensors; keys must satisfy the relevant sub-groups
                of the first stage.

        Returns:
            The tensors produced by the final stage for this group.
        """
        current_state = group
        next_state: dict[str, torch.Tensor] = {}
        for mapper in self._mappers:
            for deps in mapper.state_dependency_groups():
                # Only run the sub-groups fully satisfied by the current
                # state; the rest belong to a different pipeline group.
                if not deps.inputs <= current_state.keys():
                    continue

                next_state.update(mapper.apply({k: v for k, v in current_state.items() if k in deps.inputs}))

            current_state = next_state
            next_state = {}

        return current_state
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ModelStateMapperShard(ModelStateMapper):
    """
    Restricts another mapper to a deterministic subset (shard) of its
    dependency groups.

    Assigning each process a distinct `current_shard` splits the sub-mapper's
    groups round-robin across `total_shards` workers, so no single process has
    to materialize every tensor the sub-mapper covers. This is primarily used
    to parallelize model loading across processes or nodes.
    """

    def __init__(self, sub_mapper: ModelStateMapper, total_shards: int, current_shard: int):
        """
        Args:
            sub_mapper: The mapper whose groups are being partitioned.
            total_shards: Number of partitions the groups are split into.
            current_shard: Zero-based index of the partition this instance handles.
        """
        self._groups = self._shard_groups(
            sub_mapper.state_dependency_groups(),
            n_shards=total_shards, shard=current_shard
        )
        self._sub_mapper = sub_mapper
        self._total_shards = total_shards
        self._current_shard = current_shard

    @staticmethod
    def _shard_groups(groups: frozenset[StateGroup], n_shards: int, shard: int) -> frozenset[StateGroup]:
        """Deterministically selects every `n_shards`-th group for this shard."""
        # Sort by input keys so every process derives the same ordering.
        ordered = sorted(groups, key=lambda g: sorted(g.inputs))

        selected = []
        for position, candidate in enumerate(ordered):
            if position % n_shards == shard:
                selected.append(candidate)

        return frozenset(selected)

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """Returns only the groups assigned to this shard."""
        return self._groups

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Delegates execution to the wrapped mapper."""
        return self._sub_mapper.apply(group)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This package provides leaf mapper implementations.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from .dtensor import ModelStateMapperDistribute, ModelStateMapperGatherFullTensor
|
|
6
|
+
from .identity import ModelStateMapperIdentity
|
|
7
|
+
from .rename import ModelStateMapperRename
|
|
8
|
+
from .select_child import ModelStateMapperSelectChildModules
|
|
9
|
+
from .stack import ModelStateMapperStackTensors
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"ModelStateMapperDistribute",
|
|
13
|
+
"ModelStateMapperGatherFullTensor",
|
|
14
|
+
"ModelStateMapperIdentity",
|
|
15
|
+
"ModelStateMapperRename",
|
|
16
|
+
"ModelStateMapperSelectChildModules",
|
|
17
|
+
"ModelStateMapperStackTensors",
|
|
18
|
+
]
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch._C._distributed import Placement
|
|
5
|
+
from torch.distributed import DeviceMesh
|
|
6
|
+
from torch.distributed.tensor import DTensor, distribute_tensor
|
|
7
|
+
|
|
8
|
+
from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ModelStateMapperDistribute(ModelStateMapper):
    """
    Converts a single local Tensor into a DTensor laid out on the specified
    `device_mesh` with the specified `placements`.
    """

    def __init__(self, name: str, device_mesh: DeviceMesh | None, placements: Sequence[Placement] | None):
        """
        Args:
            name: The state key this mapper reads and rewrites.
            device_mesh: Target mesh for the DTensor; passed straight to
                `distribute_tensor` (which resolves `None` itself).
            placements: Desired placements; passed straight to
                `distribute_tensor`.
        """
        self._name = name

        self._device_mesh = device_mesh
        self._placements = placements

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """Declares a one-to-one dependency on the wrapped key."""
        key_set = frozenset((self._name,))
        return frozenset((StateGroup(inputs=key_set, outputs=key_set),))

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Distributes the tensor stored under the configured key."""
        # src_data_rank=None tells distribute_tensor not to communicate here
        # (no broadcast from a source rank).
        distributed = distribute_tensor(
            group[self._name],
            device_mesh=self._device_mesh,
            placements=self._placements,
            src_data_rank=None
        )
        return {self._name: distributed}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ModelStateMapperGatherFullTensor(ModelStateMapper):
    """
    Replaces a single DTensor with its fully materialized (unsharded) Tensor.
    """

    def __init__(self, name: str):
        """
        Args:
            name: The state key holding the DTensor to gather.
        """
        self._name = name

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """Declares a one-to-one dependency on the wrapped key."""
        key_set = frozenset((self._name,))
        return frozenset((StateGroup(inputs=key_set, outputs=key_set),))

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Gathers the DTensor under the configured key into a full tensor.

        Raises:
            ValueError: If the stored value is not a DTensor.
        """
        candidate = group[self._name]

        if not isinstance(candidate, DTensor):
            raise ValueError("Cannot gather anything but DTensor")

        return {self._name: candidate.full_tensor()}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ModelStateMapperIdentity(ModelStateMapper):
    """
    Passes a single state tensor through without modification.
    """

    def __init__(self, name: str):
        """
        Args:
            name: The state key this mapper claims and forwards.
        """
        self._name = name

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """Declares a one-to-one dependency on the wrapped key."""
        key_set = frozenset((self._name,))
        return frozenset((StateGroup(inputs=key_set, outputs=key_set),))

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Returns the input dictionary unchanged."""
        return group
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ModelStateMapperRename(ModelStateMapper):
    """
    Renames a single state tensor from `name_from` to `name_to`.
    """

    def __init__(self, name_from: str, name_to: str):
        """Remember the source and destination names for the rename."""
        self._name_from = name_from
        self._name_to = name_to

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """Declare one group consuming `name_from` and producing `name_to`."""
        rename_group = StateGroup(
            inputs=frozenset((self._name_from,)),
            outputs=frozenset((self._name_to,)),
        )
        return frozenset((rename_group,))

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Emit the source tensor under the destination name."""
        return {self._name_to: group[self._name_from]}
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ModelStateMapperSelectChildModules(ModelStateMapper):
    """
    Selects a set of keys belonging to a specific parent module (prefix) and
    renames them by removing that prefix.

    This is effectively a batch rename operation that "hoists" parameters
    from a submodule scope to the current scope.
    """

    def __init__(self, base_names: list[str], parent_name: str):
        """Remember the child key names and the parent prefix to strip.

        Args:
            base_names: Key names as they appear *without* the parent prefix.
            parent_name: Parent module name; keys are expected as
                ``f"{parent_name}.{base_name}"``.
        """
        self._base_names = base_names
        self._parent_prefix = f"{parent_name}."

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """Declare one single-input group per base name: `<parent>.<name>` -> `<name>`."""
        return frozenset([
            StateGroup(
                inputs=frozenset([self._parent_prefix + name]),
                outputs=frozenset([name])
            )
            for name in self._base_names
        ])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Strip the parent prefix from every matching key in `group`.

        Fix: the original inspected only the first item of `group` and silently
        dropped any further entries; every entry is now processed. Keys that do
        not carry the prefix are omitted, which matches the original's
        empty-dict result for a non-matching (single-item) group.
        """
        prefix_len = len(self._parent_prefix)
        return {
            name[prefix_len:]: value
            for name, value in group.items()
            if name.startswith(self._parent_prefix)
        }
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ModelStateMapperStackTensors(ModelStateMapper):
    """
    Stacks multiple input tensors with names `source_names` into a single output tensor with name `target_name`
    producing new `stack_dim` dimension.
    """

    def __init__(self, source_names: list[str], target_name: str, stack_dim: int):
        """Remember the tensors to stack, the output name, and the new dimension."""
        self._source_names = source_names
        self._target_name = target_name
        self._stack_dim = stack_dim

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """Declare one group consuming every source name and producing the target."""
        stack_group = StateGroup(
            inputs=frozenset(self._source_names),
            outputs=frozenset((self._target_name,)),
        )
        return frozenset((stack_group,))

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Stack the source tensors (in declared order) along `stack_dim`."""
        ordered = [group[source] for source in self._source_names]
        return {self._target_name: torch.stack(ordered, dim=self._stack_dim)}
|
d9d/module/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
from typing import Protocol
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@typing.runtime_checkable
class ModuleLateInit(Protocol):
    """Structural protocol for modules supporting late parameter initialization.

    Runtime-checkable: ``isinstance(obj, ModuleLateInit)`` is True for any
    object exposing a ``reset_parameters`` method.
    """

    def reset_parameters(self):
        """Resets the module parameters (i.e. performs random initialization)."""
        ...
|
|
File without changes
|