d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,99 @@
+from d9d.core.dist_context import DENSE_DOMAIN, EXPERT_DOMAIN, DistributedContext
+from d9d.module.model.qwen3_moe import Qwen3MoEForCausalLM, Qwen3MoEModel
+from d9d.module.parallelism.api import parallelize_expert_parallel, parallelize_hsdp
+from d9d.pipelining.api import PipelineStageInfo
+
+
+def parallelize_qwen3_moe_model(
+    dist_context: DistributedContext,
+    model: Qwen3MoEModel,
+    stage: PipelineStageInfo
+):
+    """
+    Parallelizes the base Qwen3 MoE model components.
+
+    This function configures the model layers for distributed execution within a pipeline
+    stage. It applies Hybrid Sharded Data Parallelism (HSDP) to dense components (embeddings,
+    norms, attention) and Expert Parallelism (EP) to the Mixture-of-Experts (MLP) layers.
+
+    Current usage constraints:
+    * Tensor Parallelism is not supported (we may implement it later).
+    * Context Parallelism is not supported (we will implement it later).
+
+    Args:
+        dist_context: The distributed context.
+        model: The Qwen3 MoE base model to parallelize.
+        stage: Information about the current pipeline stage.
+
+    Raises:
+        ValueError: If Tensor Parallel or Context Parallel is enabled in the context.
+    """
+
+    dims = dist_context.mesh_params
+    dense_mesh = dist_context.mesh_for(DENSE_DOMAIN)
+    expert_mesh = dist_context.mesh_for(EXPERT_DOMAIN)
+
+    if dims.has_tensor_parallel:
+        raise ValueError("Tensor Parallel currently is not supported for this model.")
+    if dims.has_context_parallel_replicate or dims.has_context_parallel_shard:
+        raise ValueError("Context Parallel currently is not supported for this model.")
+
+    if stage.is_current_stage_first:
+        parallelize_hsdp(
+            model.embed_tokens,
+            mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"]
+        )
+
+    if stage.is_current_stage_last:
+        parallelize_hsdp(
+            model.norm,
+            mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
+        )
+
+    for layer in model.layers.values():
+        parallelize_expert_parallel(
+            layer.mlp,
+            mesh_experts=expert_mesh["ep_replicate", "ep_shard"]
+        )
+
+        parallelize_hsdp(
+            layer.self_attn,
+            mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
+        )
+        parallelize_hsdp(
+            layer.input_layernorm,
+            mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
+        )
+        parallelize_hsdp(
+            layer.post_attention_layernorm,
+            mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
+        )
+
+
+def parallelize_qwen3_moe_for_causal_lm(
+    dist_context: DistributedContext,
+    model: Qwen3MoEForCausalLM,
+    stage: PipelineStageInfo
+):
+    """
+    Parallelizes the Qwen3 MoE Causal LM model.
+
+    This function delegates backbone parallelization to ``parallelize_qwen3_moe_model``
+    and additionally configures the language model head with Hybrid Sharded Data
+    Parallelism (HSDP).
+
+    Args:
+        dist_context: The distributed context containing device meshes and topology info.
+        model: The Qwen3 MoE Causal LM model to parallelize.
+        stage: Information about the current pipeline stage.
+    """
+
+    dense_mesh = dist_context.mesh_for(DENSE_DOMAIN)
+
+    parallelize_qwen3_moe_model(dist_context, model.model, stage)
+
+    if stage.is_current_stage_last:
+        parallelize_hsdp(
+            model.lm_head,
+            mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
+        )
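
For orientation, a minimal usage sketch of the entry point above. How the `DistributedContext`, the model, and the `PipelineStageInfo` are constructed is not part of this hunk, so the sketch simply takes them as inputs; the import path for this plan module is inferred from the file listing (d9d/module/parallelism/model/qwen3_moe.py) and should be treated as an assumption.

# Sketch only: object construction is assumed to happen elsewhere (not shown in this diff).
from d9d.core.dist_context import DistributedContext
from d9d.module.model.qwen3_moe import Qwen3MoEForCausalLM
from d9d.module.parallelism.model.qwen3_moe import parallelize_qwen3_moe_for_causal_lm
from d9d.pipelining.api import PipelineStageInfo


def build_parallel_stage(
    dist_context: DistributedContext,
    model: Qwen3MoEForCausalLM,
    stage: PipelineStageInfo,
) -> Qwen3MoEForCausalLM:
    # Applies HSDP to dense blocks, expert parallelism to MoE layers,
    # and HSDP to lm_head on the last pipeline stage (as implemented above).
    parallelize_qwen3_moe_for_causal_lm(dist_context, model, stage)
    return model
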
d9d/module/parallelism/style/shard_experts.py
ADDED

@@ -0,0 +1,60 @@
+from torch import nn
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor import (
+    Replicate,
+    Shard,
+    distribute_module,
+    distribute_tensor,
+)
+from torch.distributed.tensor.parallel import ParallelStyle
+
+from d9d.module.block.moe import GroupedLinear, MoELayer
+
+
+class ShardMoESparseExpertsParallel(ParallelStyle):
+    """
+    Parallel style that shards MoE experts across a specific mesh dimension.
+
+    This style is designed for ``MoELayer`` instances using ``GroupedLinear`` for experts.
+    It splits the experts across the specified
+    dimension of the device mesh (Expert Parallelism). Other dimensions in the
+    mesh treat the parameters as Replicated.
+
+    It also initializes the necessary distributed communication groups within the
+    MoE layer to handle token dispatching.
+    """
+
+    def __init__(self, shard_dim_name: str):
+        self._shard_dim_name = shard_dim_name
+
+    def _partition_experts(self, module_name: str, mod: nn.Module, device_mesh: DeviceMesh):
+        if not isinstance(mod, GroupedLinear):
+            raise TypeError("This plan should be applied only on GroupedLinear")
+
+        mesh_dim_names = device_mesh.mesh_dim_names
+
+        if mesh_dim_names is None:
+            raise ValueError("This plan should be applied only on named DeviceMeshes")
+
+        placements = [
+            Shard(0) if dim_name == self._shard_dim_name else Replicate()
+            for dim_name
+            in mesh_dim_names
+        ]
+        weight = nn.Parameter(
+            distribute_tensor(mod.weight, device_mesh, placements),
+            requires_grad=mod.weight.requires_grad
+        )
+        mod.weight = weight
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        if not isinstance(module, MoELayer):
+            raise TypeError("This plan should be applied only on MoELayer")
+
+        module.enable_distributed_communicator(device_mesh.get_group(self._shard_dim_name))
+
+        for submod in module.modules():
+            if isinstance(submod, GroupedLinear):
+                distribute_module(submod, device_mesh, self._partition_experts)
+
+        return module
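
A hedged sketch of applying this style on its own. In the package it is presumably wrapped by `parallelize_expert_parallel` (see d9d/module/parallelism/api/expert_parallel.py in the listing); here it is applied directly with PyTorch's `parallelize_module`, and the mesh dimension name "ep_shard" follows the slicing used in the Qwen3 MoE plan earlier. The import path is inferred from the file listing and is an assumption.

from torch.distributed import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module

from d9d.module.block.moe import MoELayer
# Import path inferred from the file listing (style/shard_experts.py); treat as an assumption.
from d9d.module.parallelism.style.shard_experts import ShardMoESparseExpertsParallel


def shard_moe_experts(moe_layer: MoELayer, expert_mesh: DeviceMesh) -> MoELayer:
    # expert_mesh must be a named mesh, e.g. with dims ("ep_replicate", "ep_shard");
    # expert weights become Shard(0) along "ep_shard" and Replicate() on the other dims.
    parallelize_module(
        moe_layer,
        expert_mesh,
        ShardMoESparseExpertsParallel(shard_dim_name="ep_shard"),
    )
    return moe_layer
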
d9d/module/parallelism/style/to_local.py
ADDED

@@ -0,0 +1,86 @@
+from typing import Any
+
+from torch import nn
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor import Placement, distribute_module, distribute_tensor
+from torch.distributed.tensor.parallel import ParallelStyle
+
+
+def _build_to_local_patched_class(
+    module: nn.Module,
+    grad_placement: tuple[Placement, ...],
+    param_names: list[str]
+) -> type:
+    param_name_to_property = {
+        param_name: property(
+            lambda self, pn=param_name: self._parameters[pn].to_local(grad_placements=grad_placement)  # type: ignore
+        )
+        for param_name
+        in param_names
+    }
+    return type(
+        f"Replicate{module.__class__.__name__}",
+        (module.__class__,),
+        param_name_to_property,
+    )
+
+
+class _ModulePatch:
+    def __init__(self, class_mapper: dict[str, type]):
+        self._class_mapper = class_mapper
+
+    def __call__(self, mod: nn.Module, *args: Any, **kwargs: Any):
+        for submod_name, submod in mod.named_modules():
+            submod.__class__ = self._class_mapper[submod_name]
+
+
+class ToLocalParallel(ParallelStyle):
+    """
+    Parallel style that distributes parameters and gradients but executes with local tensors.
+
+    This style wraps standard tensor distribution (via ``DTensor``) but injects
+    runtime hooks to temporarily unwrap ``DTensor`` parameters into local ``torch.Tensor``
+    during the forward pass.
+
+    This is useful for parallel strategies (like Replicate)
+    where the underlying calculation logic is not DTensor-aware, but the parameters must remain
+    distributed for gradient synchronization and for distributed checkpointing.
+    """
+
+    def __init__(self, param_placement: tuple[Placement, ...], grad_placement: tuple[Placement, ...]):
+        """
+        Constructs ToLocalParallel object.
+
+        Args:
+            param_placement: Tuple of placements defining how parameters are distributed.
+            grad_placement: Tuple of placements defining how gradients are synchronized.
+        """
+
+        self._grad_placement = grad_placement
+        self._param_placement = param_placement
+
+    def _distribute_params(self, name: str, module: nn.Module, device_mesh: DeviceMesh):
+        for param_name, param in module.named_parameters(recurse=False):
+            new_param = nn.Parameter(
+                distribute_tensor(param.data, device_mesh, self._param_placement),
+                requires_grad=param.requires_grad
+            )
+
+            module.register_parameter(param_name, new_param)
+
+    def _apply(self, master_module: nn.Module, device_mesh: DeviceMesh):
+        patched_classes = {}
+        original_classes = {}
+
+        for submod_name, submod in master_module.named_modules():
+            param_names = [name for name, p in submod.named_parameters(recurse=False)]
+            patched_classes[submod_name] = _build_to_local_patched_class(submod, self._grad_placement, param_names)
+            original_classes[submod_name] = submod.__class__
+
+            distribute_module(
+                submod,
+                device_mesh,
+                self._distribute_params
+            )
+
+        master_module.register_forward_pre_hook(_ModulePatch(patched_classes))
+        master_module.register_forward_hook(_ModulePatch(original_classes))
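
A hedged sketch of how this style might be used for a replicated module whose forward code is not DTensor-aware. The Replicate/Partial placement pair is an illustrative assumption (the package's own wrapper in api/replicate_parallel.py is not shown in this section), and the import path is inferred from the file listing.

from torch import nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Partial, Replicate
from torch.distributed.tensor.parallel import parallelize_module

# Import path inferred from the file listing (style/to_local.py); treat as an assumption.
from d9d.module.parallelism.style.to_local import ToLocalParallel


def replicate_with_local_compute(module: nn.Module, mesh: DeviceMesh) -> nn.Module:
    # Parameters are stored as replicated DTensors (gradient sync and distributed
    # checkpointing stay DTensor-based), while the forward pass sees plain local
    # tensors via the class-patching hooks installed above.
    parallelize_module(
        module,
        mesh,
        ToLocalParallel(param_placement=(Replicate(),), grad_placement=(Partial(),)),
    )
    return module
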
d9d/optim/__init__.py
ADDED

File without changes
d9d/optim/stochastic/adamw.py
ADDED

@@ -0,0 +1,158 @@
+from typing import cast
+
+import torch
+from torch.distributed.tensor import DTensor
+from torch.optim import Optimizer
+from torch.optim.optimizer import ParamsT, StateDict
+
+from d9d.kernel.stochastic import adamw_stochastic_bf16_
+
+_GENERATOR_STATE_KEY = "_d9d_generator_state"
+
+
+def _new_buffer(p: torch.Tensor, dtype_override: torch.dtype) -> torch.Tensor:
+    if isinstance(p, DTensor):
+        local_p = p.to_local()
+    else:
+        local_p = p
+
+    out = torch.zeros_like(local_p, dtype=dtype_override).contiguous()
+
+    if isinstance(p, DTensor):
+        out = DTensor.from_local(
+            local_tensor=out,
+            device_mesh=p.device_mesh,
+            placements=p.placements,
+            run_check=False,
+            shape=p.shape,
+            stride=p.stride(),
+        )
+
+    return out
+
+
+def _tensor_to_local(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, DTensor):
+        return tensor.to_local()
+    return tensor
+
+
+class StochasticAdamW(Optimizer):
+    """Implements the AdamW algorithm with Stochastic Rounding.
+
+    This optimizer is designed to handle stochastic rounding primarily for BF16 training,
+    leveraging a custom kernel.
+
+    Parameters must be in BF16. Gradients could be both in BF16 and FP32.
+
+    It natively supports PyTorch distributed ``DTensor`` parameters.
+
+    It maintains its own random number generator state to ensure reproducibility.
+    """
+
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: float,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 1e-2,
+        generator: torch.Generator | None = None,
+        state_dtype: torch.dtype = torch.float32,
+    ):
+        """Constructs a new StochasticAdamW optimizer.
+
+        Args:
+            params: Iterable of parameters to optimize or dicts defining parameter groups.
+            lr: Learning rate.
+            betas: Coefficients used for computing running averages of gradient and its square.
+            eps: Term added to the denominator to improve numerical stability.
+            weight_decay: Weight decay coefficient.
+            generator: Pseudorandom number generator for stochastic rounding. If None,
+                a new generator is created and seeded from the main PyTorch generator.
+            state_dtype: Data Type to use for the optimizer states.
+        """
+
+        if lr <= 0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if eps <= 0:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if weight_decay <= 0:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        if generator is None:
+            generator = torch.Generator(device="cpu")
+            # make the generator fork from pytorch's main generator
+            seed = cast(int, torch.randint(0, 2**32, (1,)).item())
+            generator.manual_seed(seed)
+
+        self._generator = generator
+
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "state_dtype": state_dtype
+        }
+        super().__init__(params, defaults)
+
+    def state_dict(self) -> StateDict:
+        state_dict = super().state_dict()
+        state_dict[_GENERATOR_STATE_KEY] = self._generator.get_state()
+        return state_dict
+
+    def load_state_dict(self, state_dict: StateDict) -> None:
+        if _GENERATOR_STATE_KEY in state_dict:
+            self._generator.set_state(state_dict.pop(_GENERATOR_STATE_KEY))
+        super().load_state_dict(state_dict)
+
+    @torch.no_grad()
+    def step(self, closure: None = None) -> None:  # type: ignore[override]
+        if closure is not None:
+            raise ValueError("Closure is not supported")
+
+        for group in self.param_groups:
+            lr = group["lr"]
+            beta1, beta2 = group["betas"]
+            eps = group["eps"]
+            weight_decay = group["weight_decay"]
+            state_dtype = group["state_dtype"]
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError("StochasticAdamW does not support sparse gradients")
+
+                state = self.state[p]
+
+                # State Initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    state["exp_avg"] = _new_buffer(p, dtype_override=state_dtype)
+                    state["exp_avg_sq"] = _new_buffer(p, dtype_override=state_dtype)
+
+                state["step"] += 1
+                exp_avg = state["exp_avg"]
+                exp_avg_sq = state["exp_avg_sq"]
+
+                adamw_stochastic_bf16_(
+                    params=_tensor_to_local(p),
+                    grads=_tensor_to_local(grad),
+                    exp_avg=_tensor_to_local(exp_avg),
+                    exp_avg_sq=_tensor_to_local(exp_avg_sq),
+                    lr=lr,
+                    beta1=beta1,
+                    beta2=beta2,
+                    eps=eps,
+                    weight_decay=weight_decay,
+                    step=state["step"],
+                    generator=self._generator
+                )
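
A brief usage sketch for the optimizer above, with illustrative hyperparameter values. The `d9d.optim.stochastic` import assumes that package's `__init__.py` (listed above, +5 lines) re-exports `StochasticAdamW`; otherwise import from the adamw module directly.

import torch
from torch import nn

# Assumes d9d/optim/stochastic/__init__.py re-exports the class.
from d9d.optim.stochastic import StochasticAdamW


def build_optimizer(model: nn.Module) -> StochasticAdamW:
    # Parameters must be BF16 (see the class docstring); optimizer state stays FP32 here.
    model.to(torch.bfloat16)
    return StochasticAdamW(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.95),
        weight_decay=0.1,
        state_dtype=torch.float32,
    )
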
d9d/peft/__init__.py
ADDED

@@ -0,0 +1,13 @@
+"""
+Provides core logic for PEFT (Parameter-Efficient Fine-Tuning) application and base definitions.
+"""
+
+from .applicator import inject_peft_and_freeze, merge_peft
+from .base import PeftInjectionResult, PeftMethod
+
+__all__ = [
+    "PeftInjectionResult",
+    "PeftMethod",
+    "inject_peft_and_freeze",
+    "merge_peft"
+]
d9d/peft/all/__init__.py
ADDED

d9d/peft/all/config.py
ADDED

@@ -0,0 +1,31 @@
+from typing import Annotated, Literal
+
+from pydantic import BaseModel, Field
+
+from d9d.peft.full_tune.config import FullTuneConfig
+from d9d.peft.lora.config import LoRAConfig
+
+
+class PeftStackConfig(BaseModel):
+    """
+    Configuration for applying a stack of multiple PEFT methods sequentially.
+
+    Attributes:
+        kind: Discriminator field, always "stack".
+        methods: A list of specific PEFT configurations (e.g., LoRA, FullTune) to apply in order.
+    """
+
+    kind: Literal["stack"] = "stack"
+
+    methods: list["AnyPeftConfig"]
+
+
+AnyPeftConfig = Annotated[
+    LoRAConfig
+    | FullTuneConfig
+    | PeftStackConfig,
+    Field(discriminator="kind"),
+]
+"""
+Union type representing any valid PEFT configuration, discriminated by the 'kind' field.
+"""
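
Because the union is discriminated by `kind`, raw dictionaries can be validated straight into the right config class. A minimal sketch, assuming pydantic v2's `TypeAdapter`; only the "stack" discriminator is visible in this diff, so the LoRA and full-tune payloads (whose fields and discriminator values live in config modules not shown here) are left out.

from pydantic import TypeAdapter

from d9d.peft.all.config import AnyPeftConfig, PeftStackConfig

# An empty stack keeps the example independent of the (not shown) LoRA / full-tune fields.
config = TypeAdapter(AnyPeftConfig).validate_python({"kind": "stack", "methods": []})
assert isinstance(config, PeftStackConfig)
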
d9d/peft/all/method.py
ADDED

@@ -0,0 +1,76 @@
+from typing import Self, cast
+
+from pydantic import BaseModel
+from torch import nn
+
+from ..all.config import PeftStackConfig
+from ..base import PeftInjectionResult, PeftMethod, TConfig
+from ..full_tune.config import FullTuneConfig
+from ..full_tune.method import FullTune
+from ..lora.config import LoRAConfig
+from ..lora.method import LoRA
+
+
+class PeftStack(PeftMethod[PeftStackConfig]):
+    """
+    A composite PEFT method that applies a list of methods sequentially.
+    """
+
+    def __init__(self, methods: list[PeftMethod]):
+        """
+        Constructs a PeftStack object.
+
+        Args:
+            methods: A list of instantiated PEFT methods to apply in order.
+        """
+
+        self._methods = methods
+
+    def inject(self, module: nn.Module) -> PeftInjectionResult:
+        params_to_train = []
+        state_mappers = []
+
+        for method in self._methods:
+            result = method.inject(module)
+            params_to_train.extend(result.parameters_to_train)
+            state_mappers.extend(result.load_state_mappers)
+
+        return PeftInjectionResult(
+            parameters_to_train=params_to_train,
+            load_state_mappers=state_mappers
+        )
+
+    def merge(self, module: nn.Module):
+        for method in self._methods[::-1]:
+            method.merge(module)
+
+    @classmethod
+    def from_config(cls, config: PeftStackConfig) -> Self:
+        methods = []
+
+        for method in config.methods:
+            methods.append(peft_method_from_config(method))
+
+        return cls(methods)
+
+
+_PEFT_CONFIG_MAP: dict[type[BaseModel], type[PeftMethod]] = {
+    LoRAConfig: LoRA,
+    FullTuneConfig: FullTune,
+    PeftStackConfig: PeftStack
+}
+
+
+def peft_method_from_config(config: TConfig) -> PeftMethod[TConfig]:
+    """
+    Factory function to instantiate the correct PeftMethod based on the configuration type.
+
+    Args:
+        config: A specific PEFT configuration object (e.g., LoRAConfig).
+
+    Returns:
+        The corresponding method instance.
+    """
+
+    method_cls = cast(type[PeftMethod[TConfig]], _PEFT_CONFIG_MAP[type(config)])
+    return method_cls.from_config(config)
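
A matching sketch of the factory side: a config object is turned into its method via `peft_method_from_config`. An empty stack is used to avoid depending on the LoRA / full-tune config fields, which are not shown in this section.

from d9d.peft.all.config import PeftStackConfig
from d9d.peft.all.method import PeftStack, peft_method_from_config

# Nested configs (LoRA, full-tune, further stacks) would be instantiated
# recursively by PeftStack.from_config.
method = peft_method_from_config(PeftStackConfig(methods=[]))
assert isinstance(method, PeftStack)
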
d9d/peft/applicator.py
ADDED

@@ -0,0 +1,47 @@
+from torch import nn
+
+from d9d.model_state.mapper import ModelStateMapper
+from d9d.model_state.mapper.compose import ModelStateMapperParallel
+
+from .base import PeftMethod
+
+
+def inject_peft_and_freeze(method: PeftMethod, module: nn.Module) -> ModelStateMapper:
+    """
+    Applies a PEFT method to a module, freezes non-trained parameters, and prepares state mapping.
+
+    This function performs three main steps:
+
+    1. Sets `requires_grad=False` for all parameters in the module.
+    2. Calls the method's `inject` to modify the model structure.
+    3. Sets `requires_grad=True` for the parameters returned by the injection result.
+
+    Args:
+        method: The PEFT method strategy to apply.
+        module: The PyTorch module to modify.
+
+    Returns:
+        A ModelStateMapper capable of loading checkpoint weights into the modified structure.
+    """
+
+    for param in module.parameters():
+        param.requires_grad = False
+
+    result = method.inject(module)
+
+    for param in result.parameters_to_train:
+        param.requires_grad = True
+
+    return ModelStateMapperParallel(result.load_state_mappers)
+
+
+def merge_peft(method: PeftMethod, module: nn.Module):
+    """
+    Merges PEFT adaptations back into the base model weights.
+
+    Args:
+        method: The PEFT method strategy originally applied.
+        module: The PyTorch module to merge.
+    """
+
+    method.merge(module)
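
A hedged end-to-end sketch of the intended lifecycle: freeze and inject before training, merge afterwards. The training loop and the way the returned `ModelStateMapper` is fed to the checkpoint reader are outside this file and left as comments.

from torch import nn

from d9d.peft import PeftMethod, inject_peft_and_freeze, merge_peft


def run_peft_finetune(model: nn.Module, method: PeftMethod) -> nn.Module:
    # Freeze everything, let the method graft its adapters, and keep the mapper
    # that knows how to route base-model checkpoint weights into the new structure.
    state_mapper = inject_peft_and_freeze(method, model)

    trainable = [p for p in model.parameters() if p.requires_grad]
    # ... load base weights through `state_mapper` and train `trainable` (omitted) ...

    # Fold the adapters back into the base weights before export / inference.
    merge_peft(method, model)
    return model
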
d9d/peft/base.py
ADDED

@@ -0,0 +1,70 @@
+import abc
+import dataclasses
+from typing import Generic, Self, TypeVar
+
+from pydantic import BaseModel
+from torch import nn
+
+from d9d.model_state.mapper import ModelStateMapper
+
+
+@dataclasses.dataclass(slots=True)
+class PeftInjectionResult:
+    """
+    Encapsulates the result of injecting a PEFT method into a model.
+
+    Attributes:
+        parameters_to_train: A list of parameters that should remain trainable.
+        load_state_mappers: A list of mappers required to load pre-trained weights into the modified structure.
+    """
+
+    parameters_to_train: list[nn.Parameter]
+    load_state_mappers: list[ModelStateMapper]
+
+
+TConfig = TypeVar("TConfig", bound=BaseModel)
+
+
+class PeftMethod(abc.ABC, Generic[TConfig]):
+    """
+    Abstract base class for all Parameter-Efficient Fine-Tuning methods.
+    """
+
+    @abc.abstractmethod
+    def inject(self, module: nn.Module) -> PeftInjectionResult:
+        """
+        Modifies the module in-place to apply the PEFT strategy.
+
+        Args:
+            module: The PyTorch module to modify.
+
+        Returns:
+            Result object containing trainable parameters and structure mappers.
+        """
+        ...
+
+    @abc.abstractmethod
+    def merge(self, module: nn.Module):
+        """
+        Merges the trained adapters back into the base model parameters.
+
+        Args:
+            module: The PyTorch module to update.
+        """
+
+        ...
+
+    @classmethod
+    @abc.abstractmethod
+    def from_config(cls, config: TConfig) -> Self:
+        """
+        Creates an instance of the method from a configuration object.
+
+        Args:
+            config: The configuration object.
+
+        Returns:
+            An instance of the PeftMethod.
+        """
+
+        ...
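
To illustrate the contract, a hypothetical bias-only method implementing the three abstract hooks. This class is not part of the package; its config and `kind` value are made up for the example and are not registered in the factory map shown earlier.

from typing import Self

from pydantic import BaseModel
from torch import nn

from d9d.peft import PeftInjectionResult, PeftMethod


class BiasTuneConfig(BaseModel):
    kind: str = "bias_tune"  # hypothetical; not part of AnyPeftConfig


class BiasTune(PeftMethod[BiasTuneConfig]):
    """Toy method: train only bias parameters; nothing to merge back."""

    def inject(self, module: nn.Module) -> PeftInjectionResult:
        biases = [
            param for name, param in module.named_parameters()
            if name.split(".")[-1] == "bias"
        ]
        return PeftInjectionResult(parameters_to_train=biases, load_state_mappers=[])

    def merge(self, module: nn.Module):
        # Biases are trained in place, so there is nothing to fold back.
        pass

    @classmethod
    def from_config(cls, config: BiasTuneConfig) -> Self:
        return cls()
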