d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/peft/full_tune/config.py
ADDED
@@ -0,0 +1,20 @@
from re import Pattern
from typing import Literal

from pydantic import BaseModel


class FullTuneConfig(BaseModel):
    """
    Configuration for Full Fine-Tuning.

    Allows specifying which modules should be fully fine-tuned using regex patterns.

    Attributes:
        kind: Discriminator field, always "full_tune".
        module_name_pattern: Regular expression matching module names to unfreeze.
    """

    kind: Literal["full_tune"] = "full_tune"

    module_name_pattern: Pattern
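A minimal construction sketch for this config (the regex below is an illustrative example, not taken from the package):

import re

from d9d.peft.full_tune.config import FullTuneConfig

# Hypothetical pattern: unfreeze every attention projection module.
config = FullTuneConfig(module_name_pattern=re.compile(r".*attn\.(q|k|v|o)_proj"))
print(config.kind)  # "full_tune"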
d9d/peft/full_tune/method.py
ADDED
@@ -0,0 +1,46 @@
from typing import Self

from torch import nn

from ..base import PeftInjectionResult, PeftMethod
from .config import FullTuneConfig


class FullTune(PeftMethod[FullTuneConfig]):
    """
    Implements Full Fine-Tuning as a 'PEFT' method.

    Instead of injecting adapters, this method simply identifies existing parameters
    that match the configuration pattern and marks them for training.
    """

    def __init__(self, config: FullTuneConfig):
        """
        Constructs a FullTune object.

        Args:
            config: Configuration defining the module name patterns to fine-tune.
        """

        self._config = config

    def inject(self, module: nn.Module) -> PeftInjectionResult:
        params_to_train = []

        for mod_name, mod in module.named_modules():
            is_applicable = self._config.module_name_pattern.fullmatch(mod_name)

            if is_applicable:
                params_to_train.extend(mod.parameters())

        return PeftInjectionResult(
            parameters_to_train=params_to_train,
            load_state_mappers=[]
        )

    def merge(self, module: nn.Module):
        pass  # do nothing here

    @classmethod
    def from_config(cls, config: FullTuneConfig) -> Self:
        return cls(config)
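A minimal usage sketch, assuming the package layout above; the toy model and its module names ("encoder", "head") are illustrative only:

import re

from torch import nn

from d9d.peft.full_tune.config import FullTuneConfig
from d9d.peft.full_tune.method import FullTune

# Toy module tree with two named children.
model = nn.Sequential()
model.add_module("encoder", nn.Linear(16, 16))
model.add_module("head", nn.Linear(16, 4))

method = FullTune.from_config(FullTuneConfig(module_name_pattern=re.compile("head")))
result = method.inject(model)
# Only parameters of modules whose full name matches the pattern are returned.
print(len(result.parameters_to_train))  # 2 (head.weight and head.bias)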
d9d/peft/lora/__init__.py
ADDED
@@ -0,0 +1,15 @@
"""
Package for Low-Rank Adaptation (LoRA) implementation.
"""

from .config import LoRAConfig, LoRAParameters
from .layer import LoRAGroupedLinear, LoRALinear
from .method import LoRA

__all__ = [
    "LoRA",
    "LoRAConfig",
    "LoRAGroupedLinear",
    "LoRALinear",
    "LoRAParameters"
]
d9d/peft/lora/config.py
ADDED
@@ -0,0 +1,35 @@
from re import Pattern
from typing import Literal

from pydantic import BaseModel


class LoRAParameters(BaseModel):
    """
    Hyperparameters for LoRA layers.

    Attributes:
        r: Rank of the low-rank adaptation matrices.
        alpha: Scaling factor for the learned weights.
        dropout: Dropout probability for the input to LoRA layers.
    """

    r: int
    alpha: int
    dropout: float


class LoRAConfig(BaseModel):
    """
    Configuration for LoRA application.

    Attributes:
        kind: Discriminator field, always "lora".
        module_name_pattern: Regular expression matching module names to wrap with LoRA.
        params: Hyperparameters for the LoRA layers.
    """

    kind: Literal["lora"] = "lora"

    module_name_pattern: Pattern
    params: LoRAParameters
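A short construction sketch; the pattern and hyperparameter values below are illustrative, not defaults from the package:

import re

from d9d.peft.lora.config import LoRAConfig, LoRAParameters

# Rank-8 adapters with alpha 16 on every q/v projection (hypothetical naming).
config = LoRAConfig(
    module_name_pattern=re.compile(r".*\.(q_proj|v_proj)"),
    params=LoRAParameters(r=8, alpha=16, dropout=0.05),
)
print(config.params.alpha / config.params.r)  # effective LoRA scale: 2.0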
d9d/peft/lora/layer.py
ADDED
@@ -0,0 +1,177 @@
import torch
from torch import nn

from d9d.module.block.moe import GroupedLinear

from .config import LoRAParameters


class LoRALinear(nn.Module):
    """
    A LoRA wrapper around a standard PyTorch Linear layer.

    Wraps a base linear layer and adds low-rank adaptation matrices A and B.

    Attributes:
        lora_A: The A matrix (in_features -> r).
        lora_B: The B matrix (r -> out_features).
        base: The original base Linear layer.
        dropout: Scaling dropout layer.
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        params: LoRAParameters
    ):
        """
        Constructs a LoRALinear layer.

        Args:
            base_layer: The original Linear layer to wrap.
            params: LoRA hyperparameters (r, alpha, dropout).

        Raises:
            ValueError: If the base layer has a bias (currently unsupported).
        """

        super().__init__()
        self.lora_A = nn.Linear(
            base_layer.in_features, params.r, bias=False,
            device=base_layer.weight.device,
            dtype=base_layer.weight.dtype
        )
        self.lora_B = nn.Linear(
            params.r, base_layer.out_features, bias=False,
            device=base_layer.weight.device,
            dtype=base_layer.weight.dtype
        )
        self.base = base_layer

        if base_layer.bias is not None:
            raise ValueError("LoRA is unsupported with biased linear layers")

        self.dropout: nn.Dropout = nn.Dropout(params.dropout)

        self._scale: float = params.alpha / params.r

        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Takes input tensor, computes base output and LoRA adaptation, and returns the sum.

        Args:
            x: Input tensor.

        Returns:
            The output of base(x) + scale * (B @ A @ dropout(x)).
        """

        base_x = self.base(x)
        adapt_x = self._scale * self.lora_B(self.lora_A(self.dropout(x)))
        return base_x + adapt_x

    @torch.no_grad()
    def merge_with_base_(self) -> nn.Linear:
        """
        Collapse the LoRA weights into the base linear layer.

        Returns:
            The modified base linear layer with updated weights.
        """

        mod = self.base
        mod.weight.data += (self.lora_B.weight.data @ self.lora_A.weight.data) * self._scale
        return mod

    def reset_parameters(self):
        """
        Resets LoRA parameters. A is random, B is zeroed.
        """

        self.lora_A.reset_parameters()
        nn.init.zeros_(self.lora_B.weight)


class LoRAGroupedLinear(nn.Module):
    """
    A LoRA wrapper around a GroupedLinear layer (commonly used in MoE or grouped query attention).

    Attributes:
        lora_A: The A matrix (grouped linear).
        lora_B: The B matrix (grouped linear).
        base: The original base GroupedLinear layer.
        dropout: Scaling dropout layer.
    """

    def __init__(
        self,
        base_layer: GroupedLinear,
        params: LoRAParameters
    ):
        """
        Constructs a LoRAGroupedLinear layer.

        Args:
            base_layer: The original GroupedLinear layer to wrap.
            params: LoRA hyperparameters.
        """

        super().__init__()
        self.lora_A = GroupedLinear(
            base_layer.n_groups, base_layer.in_features, params.r,
            device=base_layer.weight.device,
            dtype=base_layer.weight.dtype
        )
        self.lora_B = GroupedLinear(
            base_layer.n_groups,
            params.r,
            base_layer.out_features,
            device=base_layer.weight.device,
            dtype=base_layer.weight.dtype
        )
        self.base = base_layer

        self.dropout = nn.Dropout(params.dropout)

        self._scale = params.alpha / params.r

        self.reset_parameters()

    def forward(self, x: torch.Tensor, x_groups: torch.Tensor) -> torch.Tensor:
        """
        Computes forward pass for grouped inputs.

        Args:
            x: Input tensor.
            x_groups: A tensor indicating group indices for each input.

        Returns:
            Combined output of base and LoRA path.
        """

        base_x = self.base(x, x_groups)
        adapt_x = self._scale * self.lora_B(self.lora_A(self.dropout(x), x_groups), x_groups)
        return base_x + adapt_x

    @torch.no_grad()
    def merge_with_base_(self) -> GroupedLinear:
        """
        Collapse the LoRA weights into the base GroupedLinear layer.

        Returns:
            The modified GroupedLinear layer.
        """

        mod = self.base
        mod.weight.data += (torch.bmm(self.lora_A.weight.data, self.lora_B.weight.data)) * self._scale
        return mod

    def reset_parameters(self):
        """
        Resets LoRA parameters. A is random, B is zeroed.
        """

        self.lora_A.reset_parameters()
        nn.init.zeros_(self.lora_B.weight)
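A small sketch of the wrap-and-merge round trip for the plain linear case (shapes and hyperparameters are illustrative). Because lora_B is zero-initialised, the wrapper's output matches the base layer right after construction, and merging at that point leaves the base weights unchanged:

import torch
from torch import nn

from d9d.peft.lora.config import LoRAParameters
from d9d.peft.lora.layer import LoRALinear

base = nn.Linear(32, 64, bias=False)
lora = LoRALinear(base, LoRAParameters(r=4, alpha=8, dropout=0.0))

x = torch.randn(2, 32)
# lora_B starts at zero, so the adapted output equals the base output.
assert torch.allclose(lora(x), base(x))

merged = lora.merge_with_base_()  # folds scale * (B @ A), currently zero, into base.weight
assert torch.allclose(merged(x), base(x))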
d9d/peft/lora/method.py
ADDED
@@ -0,0 +1,132 @@
from typing import Self

import torch
from torch import nn

from d9d.model_state.mapper import ModelStateMapper
from d9d.model_state.mapper.leaf import ModelStateMapperRename
from d9d.module.block.moe import GroupedLinear

from ..base import PeftInjectionResult, PeftMethod
from .config import LoRAConfig
from .layer import LoRAGroupedLinear, LoRALinear

_CAN_APPLY_MODULES = (nn.Linear, GroupedLinear)
_LORA_MODULES = (LoRALinear, LoRAGroupedLinear)


def named_modules_without_lora(
    module: nn.Module,
    memo: set[nn.Module] | None = None,
    prefix: str = "",
    remove_duplicate: bool = True
):
    """
    Yields named modules, skipping submodules that are already LoRA layers.

    This prevents recursively re-injecting LoRA into an already wrapped layer during
    traversal.

    Args:
        module: The root module to traverse.
        memo: Set of processed modules to avoid duplicates.
        prefix: Current namespace prefix.
        remove_duplicate: Whether to skip modules seen in memo.

    Yields:
        Tuple of (name, module).
    """

    if isinstance(module, _LORA_MODULES):
        return

    if memo is None:
        memo = set()
    if module in memo:
        return

    if remove_duplicate:
        memo.add(module)

    yield prefix, module

    for name, submodule in module.named_children():
        if submodule is None:
            continue

        submodule_prefix = prefix + ("." if prefix else "") + name
        yield from named_modules_without_lora(submodule, memo, submodule_prefix, remove_duplicate)


class LoRA(PeftMethod[LoRAConfig]):
    """
    Implements the Low-Rank Adaptation (LoRA) injection strategy.

    It scans the module structure for `nn.Linear` or `GroupedLinear` layers matching
    the configured name pattern. Matched layers are replaced with LoRA wrappers.

    It also generates `ModelStateMapperRename` objects. Since the original weight
    `layer.weight` is now at `layer.base.weight` inside the wrapper, the mapper
    ensures that loading a standard checkpoint still works by redirecting the key.
    """

    def __init__(self, config: LoRAConfig):
        """
        Constructs a LoRA method.

        Args:
            config: LoRA configuration containing patterns and hyperparameters.
        """

        self._config = config

    def inject(self, module: nn.Module) -> PeftInjectionResult:
        params_to_train: list[nn.Parameter] = []
        state_mappers: list[ModelStateMapper] = []

        for mod_name, mod in named_modules_without_lora(module):
            if not isinstance(mod, _CAN_APPLY_MODULES):
                continue

            if not self._config.module_name_pattern.fullmatch(mod_name):
                continue

            lora_mod: LoRALinear | LoRAGroupedLinear
            if isinstance(mod, nn.Linear):
                lora_mod = LoRALinear(mod, self._config.params)
            elif isinstance(mod, GroupedLinear):
                lora_mod = LoRAGroupedLinear(mod, self._config.params)
            else:
                raise ValueError(f"Unknown layer {type(mod)} for LoRA")

            params_to_train.extend(lora_mod.lora_A.parameters())
            params_to_train.extend(lora_mod.lora_B.parameters())

            state_mappers.append(ModelStateMapperRename(
                name_from=f"{mod_name}.weight",
                name_to=f"{mod_name}.base.weight"
            ))

            module.set_submodule(mod_name, lora_mod)

        return PeftInjectionResult(
            parameters_to_train=params_to_train,
            load_state_mappers=state_mappers
        )

    def merge(self, module: nn.Module):
        for mod_name, mod in module.named_modules():
            if not isinstance(mod, _LORA_MODULES):
                continue

            if not self._config.module_name_pattern.fullmatch(mod_name):
                continue

            with torch.no_grad():
                orig_mod = mod.merge_with_base_()

            module.set_submodule(mod_name, orig_mod)

    @classmethod
    def from_config(cls, config: LoRAConfig) -> Self:
        return cls(config)
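An end-to-end injection sketch on a toy model, assuming the layout above; the module name "proj" and the hyperparameters are illustrative:

import re

from torch import nn

from d9d.peft.lora.config import LoRAConfig, LoRAParameters
from d9d.peft.lora.method import LoRA

# Toy model with a single linear projection targeted by the pattern.
model = nn.Sequential()
model.add_module("proj", nn.Linear(16, 16, bias=False))

method = LoRA.from_config(LoRAConfig(
    module_name_pattern=re.compile("proj"),
    params=LoRAParameters(r=4, alpha=8, dropout=0.0),
))
result = method.inject(model)

# The matched layer is now wrapped, only the adapter matrices are trainable,
# and a rename mapper redirects "proj.weight" -> "proj.base.weight" on load.
print(type(model.get_submodule("proj")).__name__)  # LoRALinear
print(len(result.parameters_to_train))             # 2 (lora_A.weight, lora_B.weight)

method.merge(model)                                 # collapse adapters back into nn.Linear
print(type(model.get_submodule("proj")).__name__)  # Linear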
d9d/pipelining/__init__.py
File without changes
d9d/pipelining/api/__init__.py
ADDED
@@ -0,0 +1,19 @@
"""
Pipelining API that is intended to be accessible by end user.
"""

from .module import (
    ModuleSupportsPipelining,
    PipelineStageInfo,
    distribute_layers_for_pipeline_stage,
)
from .schedule import PipelineSchedule
from .sharding import PipelineShardingSpec

__all__ = [
    "ModuleSupportsPipelining",
    "PipelineSchedule",
    "PipelineShardingSpec",
    "PipelineStageInfo",
    "distribute_layers_for_pipeline_stage"
]
d9d/pipelining/api/module.py
ADDED
@@ -0,0 +1,149 @@
import dataclasses
import typing

import torch


@dataclasses.dataclass
class PipelineStageInfo:
    """
    Holds information about the current position within the distributed pipeline.

    Attributes:
        current_stage: The 0-based index of the current pipeline stage.
        num_stages: The total number of stages in the pipeline.
    """

    current_stage: int
    num_stages: int

    @property
    def is_current_stage_first(self) -> bool:
        """
        Determines if this is the first stage in the pipeline.

        Returns:
            True if current_stage is 0.
        """

        return self.current_stage == 0

    @property
    def is_current_stage_last(self) -> bool:
        """
        Determines if this is the last stage in the pipeline.

        Returns:
            True if current_stage is the last index.
        """

        return self.current_stage == self.num_stages - 1


def distribute_layers_for_pipeline_stage(
    num_layers: int,
    num_virtual_layers_pre: int,
    num_virtual_layers_post: int,
    stage: PipelineStageInfo
) -> tuple[int, int]:
    """
    Calculates the layer index range for a specific pipeline stage.

    This function distributes a given number of layers across multiple pipeline
    stages as evenly as possible. It accounts for additional, non-layer
    computational load on the first and last stages (e.g., embeddings and the
    LM head) by using the concept of 'virtual layers' to reserve capacity.

    Args:
        num_layers: The total number of primary model layers to be distributed
            (e.g., the transformer blocks).
        num_virtual_layers_pre: The number of 'virtual' layers representing the
            computational cost of modules on the *first* stage, before the main
            layers (e.g., token and positional embeddings).
        num_virtual_layers_post: The number of 'virtual' layers representing the
            computational cost of modules on the *last* stage, after the main
            layers (e.g., the final layer normalization and LM head).
        stage: An object containing total stages and current stage index.

    Returns:
        A tuple (start_index, end_index), representing the slice of layers for
        the given stage. The start_index is inclusive and the end_index is
        exclusive.

    Raises:
        ValueError: If the pipeline configuration results in a stage having zero
            or negative layers assigned (pipeline too long for the model size).
    """

    num_layers_virtual = num_layers + num_virtual_layers_pre + num_virtual_layers_post

    base_layers_per_stage = num_layers_virtual // stage.num_stages
    extra_layers = num_layers_virtual % stage.num_stages

    layer_count_per_stage = []

    for proposed_stage_i in range(stage.num_stages):
        proposed_stage = PipelineStageInfo(num_stages=stage.num_stages, current_stage=proposed_stage_i)
        layers = base_layers_per_stage + 1 if proposed_stage_i < extra_layers else base_layers_per_stage

        adjustment = 0
        if proposed_stage.is_current_stage_first:
            adjustment += num_virtual_layers_pre
        if proposed_stage.is_current_stage_last:
            adjustment += num_virtual_layers_post

        actual_layers = layers - adjustment

        if actual_layers <= 0:
            raise ValueError(f"Tried to distribute layers, but got {actual_layers} on "
                             f"stage {proposed_stage.current_stage}. Perhaps the pipeline is too long for this model?")

        layer_count_per_stage.append(actual_layers)

    start_layer_id = sum(layer_count_per_stage[:stage.current_stage])
    num_layers_in_stage = layer_count_per_stage[stage.current_stage]

    return start_layer_id, start_layer_id + num_layers_in_stage


@typing.runtime_checkable
class ModuleSupportsPipelining(typing.Protocol):
    """
    Protocol for modules that support pipeline parallelism metadata inference.

    Classes implementing this protocol enable the framework to pre-calculate
    tensor shapes and types required for inter-stage communication (p2p)
    without executing the full forward pass.
    """

    def infer_stage_inputs_from_pipeline_inputs(
        self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        """
        Infers the input tensors metadata for the current pipeline stage based on global batch inputs.

        Args:
            inputs: Global inputs for the pipeline.
            n_microbatches: Number of microbatches the global batch is split into.

        Returns:
            Dictionary of input tensors expected by this specific stage locally.
        """

        ...

    def infer_stage_outputs_from_pipeline_inputs(
        self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        """
        Infers the output tensors metadata for the current pipeline stage based on global batch inputs.

        Args:
            inputs: Global inputs for the pipeline (typically a batch).
            n_microbatches: Number of microbatches the global batch is split into.

        Returns:
            Dictionary of output tensors produced by this specific stage locally.
        """

        ...
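A worked example of the split, assuming 30 transformer layers, one virtual layer each for the embedding and the LM head, and four stages; the per-stage layer counts come out to 7/8/8/7:

from d9d.pipelining.api.module import PipelineStageInfo, distribute_layers_for_pipeline_stage

for i in range(4):
    stage = PipelineStageInfo(current_stage=i, num_stages=4)
    print(i, distribute_layers_for_pipeline_stage(
        num_layers=30,
        num_virtual_layers_pre=1,
        num_virtual_layers_post=1,
        stage=stage,
    ))
# 0 (0, 7)    first stage keeps capacity for the embedding
# 1 (7, 15)
# 2 (15, 23)
# 3 (23, 30)  last stage keeps capacity for the LM head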
d9d/pipelining/api/schedule.py
ADDED
@@ -0,0 +1,50 @@
import abc
from typing import Any

import torch

from .sharding import PipelineShardingSpec

# TODO: feature - support any PyTrees as pipeline parameters


class PipelineSchedule(abc.ABC):
    """Abstract base class defining the interface for pipeline execution schedules."""

    @abc.abstractmethod
    def configure_buffers(
        self,
        inputs: dict[str, torch.Tensor],
        kwargs: dict[str, Any],
        sharding_spec: PipelineShardingSpec | None
    ):
        """
        Configures internal state and buffers based on input shapes.

        This method allows the schedule to pre-allocate memory or setup sharding
        specifications based on the structure of the input data before execution begins.

        Args:
            inputs: A dictionary of input tensors.
            kwargs: A dictionary of keyword arguments.
            sharding_spec: A specification defining how inputs and kwargs should be split
                into micro-batches. If None, assumes standard split-by-zero-dim behavior.
        """

        ...

    @abc.abstractmethod
    def step(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
        """
        Executes a single pipeline step using the provided inputs.

        This typically involves distributing inputs across microbatches,
        executing forward and backward passes according to the specific schedule logic,
        and handling communications between stages.

        Args:
            inputs: A dictionary of global input tensors.
            kwargs: A dictionary of global keyword arguments.
        """

        ...
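For illustration only, a minimal concrete subclass that satisfies the abstract interface; it merely records input shapes and performs no pipeline work. The package's real schedules (GPipe, 1F1B, looped BFS, DualPipeV, ZeroBubbleV) live under d9d/pipelining/infra/schedule:

from typing import Any

import torch

from d9d.pipelining.api import PipelineSchedule, PipelineShardingSpec


class NoOpSchedule(PipelineSchedule):
    """Illustrative schedule that records input shapes and drives no stages."""

    def configure_buffers(
        self,
        inputs: dict[str, torch.Tensor],
        kwargs: dict[str, Any],
        sharding_spec: PipelineShardingSpec | None
    ):
        # Record shapes; a real schedule would pre-allocate p2p buffers here.
        self._shapes = {name: tuple(t.shape) for name, t in inputs.items()}

    def step(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
        # A real schedule would split inputs into microbatches and run
        # forward/backward passes with inter-stage communication here.
        pass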
d9d/pipelining/factory/__init__.py
ADDED
@@ -0,0 +1,21 @@
from .config import (
    AnyPipelineScheduleConfig,
    PipelineSchedule1F1BConfig,
    PipelineScheduleDualPipeVConfig,
    PipelineScheduleGPipeConfig,
    PipelineScheduleInferenceConfig,
    PipelineScheduleLoopedBFSConfig,
    PipelineScheduleZeroBubbleVConfig,
)
from .factory import build_schedule

__all__ = [
    "AnyPipelineScheduleConfig",
    "PipelineSchedule1F1BConfig",
    "PipelineScheduleDualPipeVConfig",
    "PipelineScheduleGPipeConfig",
    "PipelineScheduleInferenceConfig",
    "PipelineScheduleLoopedBFSConfig",
    "PipelineScheduleZeroBubbleVConfig",
    "build_schedule"
]