d9d-0.1.0-py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/loop/control/lr_scheduler_provider.py
ADDED
@@ -0,0 +1,47 @@
import abc
import dataclasses
import typing
from typing import Protocol

from torch.optim import Optimizer

from d9d.core.dist_context import DistributedContext
from d9d.core.protocol import LRSchedulerProtocol


@dataclasses.dataclass(kw_only=True)
class InitializeLRSchedulerContext:
    """
    Context data required to initialize an LR scheduler.

    Attributes:
        dist_context: The distributed context.
        total_steps: The total number of training steps.
        optimizer: The optimizer instance that the scheduler will control.
    """

    dist_context: DistributedContext
    total_steps: int
    optimizer: Optimizer


@typing.runtime_checkable
class LRSchedulerProvider(Protocol):
    """
    Protocol for defining how learning rate schedulers are created.
    """

    @abc.abstractmethod
    def __call__(
        self,
        context: InitializeLRSchedulerContext
    ) -> LRSchedulerProtocol:
        """
        Initializes the LR scheduler for a specific model pipeline stage.

        Args:
            context: Context for this operation.

        Returns:
            The instantiated LR scheduler adhering to the protocol.
        """
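
Because LRSchedulerProvider is a runtime-checkable Protocol with a single __call__, a plain function can serve as a provider. Below is a minimal sketch, assuming that torch.optim.lr_scheduler.LambdaLR satisfies LRSchedulerProtocol (i.e. that the protocol mirrors the usual step()/state_dict() scheduler surface); the 10% warm-up fraction is an arbitrary illustration value, not a framework default.

from torch.optim.lr_scheduler import LambdaLR


def linear_warmup_provider(context: InitializeLRSchedulerContext) -> LambdaLR:
    # Warm up linearly over the first 10% of training, then hold the LR constant.
    warmup_steps = max(1, context.total_steps // 10)

    def schedule(step: int) -> float:
        return min(1.0, step / warmup_steps)

    return LambdaLR(context.optimizer, lr_lambda=schedule)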

d9d/loop/control/model_provider.py
ADDED
@@ -0,0 +1,162 @@
import abc
import dataclasses
from typing import Generic, TypeVar

from torch import nn

from d9d.core.dist_context import DistributedContext
from d9d.core.types import ScalarTree
from d9d.model_state.mapper import ModelStateMapper
from d9d.pipelining.api import PipelineStageInfo


@dataclasses.dataclass(kw_only=True)
class InitializeModelStageContext:
    """
    Context data required for initializing a specific model pipeline stage.

    Attributes:
        dist_context: The distributed execution context.
        stage: Metadata describing the current pipeline stage being initialized.
    """

    dist_context: DistributedContext
    stage: PipelineStageInfo


TModel = TypeVar("TModel", bound=nn.Module)


@dataclasses.dataclass(kw_only=True)
class InitializeModelStageResult(Generic[TModel]):
    """
    The result of initializing a model stage.

    Attributes:
        model: The PyTorch module.
        state_mapper: The mapper defining how to load weights into this module.
    """

    model: TModel
    state_mapper: ModelStateMapper


@dataclasses.dataclass(kw_only=True)
class ParallelizeModelStageContext(Generic[TModel]):
    """
    Context data required for horizontally parallelizing a model stage.

    Attributes:
        dist_context: The distributed execution context.
        stage: Metadata describing the current pipeline stage.
        model: The PyTorch module to be parallelized.
    """

    dist_context: DistributedContext
    stage: PipelineStageInfo
    model: TModel


@dataclasses.dataclass(kw_only=True)
class PrepareExportModelStageContext(Generic[TModel]):
    """
    Context data required for preparing a model stage for export.

    Attributes:
        dist_context: The distributed execution context.
        model: The PyTorch module to be exported.
    """

    dist_context: DistributedContext
    model: TModel


@dataclasses.dataclass(kw_only=True)
class PrepareExportModelStageResult:
    """
    The result of preparing a model stage for export.

    Attributes:
        state_mapper: The mapper defining how model parameters map to disk storage.
    """

    state_mapper: ModelStateMapper


class ModelProvider(abc.ABC, Generic[TModel]):
    """
    Abstract interface for defining the lifecycle of a distributed model.

    This provider handles initialization, parallelization (sharding, replication, etc.)
    and export preparation for models within the d9d framework.
    """

    @abc.abstractmethod
    def initialize_model_stage(
        self,
        context: InitializeModelStageContext
    ) -> InitializeModelStageResult[TModel]:
        """
        Initializes the model architecture for a specific pipeline stage.

        This method is responsible for constructing the `nn.Module` for the requested stage.

        Construction occurs within a meta-device context; therefore, weights
        should not be loaded directly here. Instead, a `ModelStateMapper` must be returned
        to define how weights from a checkpoint map to the newly created module parameters.

        This allows for architecture modifications, such as injecting LoRA adapters,
        provided that the returned mapper reflects the new structure.

        Args:
            context: Context for this operation.

        Returns:
            Result of this operation.
        """

        ...

    @abc.abstractmethod
    def parallelize_model_stage(
        self,
        context: ParallelizeModelStageContext[TModel]
    ):
        """
        Converts the model parameters into distributed tensors (DTensors).

        Implementations should modify the model in-place. This involves converting
        standard parameters into DTensors by replicating or sharding them according
        to the desired parallelism strategies.

        Args:
            context: Context for this operation.
        """

    @abc.abstractmethod
    def prepare_export_model_stage(
        self,
        context: PrepareExportModelStageContext[TModel]
    ) -> PrepareExportModelStageResult:
        """
        Prepares the state mapper required for saving the model to disk.

        This method defines how the current in-memory model structure maps back to the
        serialized checkpoint format.

        Args:
            context: Context for this operation.

        Returns:
            Result of this operation.
        """

    def dump_hparams(self) -> ScalarTree:
        """
        Exports hyperparameters associated with this model for logging.

        Returns:
            A dictionary of hyperparameter names and values.
        """

        return {}
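
For orientation, here is a minimal single-stage ModelProvider sketch. The TinyMLPProvider model is purely illustrative, and make_identity_mapper() is a hypothetical stand-in for whichever concrete ModelStateMapper the framework supplies (the leaf mappers listed above, e.g. d9d/model_state/mapper/leaf/identity.py, suggest an identity mapper exists, but its exact constructor is not shown in this diff).

from torch import nn


class TinyMLPProvider(ModelProvider[nn.Sequential]):
    def initialize_model_stage(
        self,
        context: InitializeModelStageContext
    ) -> InitializeModelStageResult[nn.Sequential]:
        # Runs under a meta-device context: build the structure, do not load weights.
        model = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
        return InitializeModelStageResult(
            model=model,
            state_mapper=make_identity_mapper(),  # hypothetical helper
        )

    def parallelize_model_stage(self, context: ParallelizeModelStageContext[nn.Sequential]):
        # Single-device sketch: nothing to shard or replicate in-place.
        pass

    def prepare_export_model_stage(
        self,
        context: PrepareExportModelStageContext[nn.Sequential]
    ) -> PrepareExportModelStageResult:
        # Map the in-memory structure straight back to the checkpoint layout.
        return PrepareExportModelStageResult(state_mapper=make_identity_mapper())

    def dump_hparams(self) -> ScalarTree:
        return {"hidden_dim": 64}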

d9d/loop/control/optimizer_provider.py
ADDED
@@ -0,0 +1,45 @@
import abc
import dataclasses
import typing
from typing import Protocol

from torch import nn
from torch.optim import Optimizer

from d9d.core.dist_context import DistributedContext


@dataclasses.dataclass(kw_only=True)
class InitializeOptimizerStageContext:
    """
    Context data required to initialize an optimizer.

    Attributes:
        dist_context: The distributed context.
        model: The model instance whose parameters will be optimized.
    """

    dist_context: DistributedContext
    model: nn.Module


@typing.runtime_checkable
class OptimizerProvider(Protocol):
    """
    Protocol for defining how optimizers are created for model pipeline stages.
    """

    @abc.abstractmethod
    def __call__(
        self,
        context: InitializeOptimizerStageContext
    ) -> Optimizer:
        """
        Initializes the optimizer for a specific training stage.

        Args:
            context: Context for this operation.

        Returns:
            The instantiated PyTorch optimizer.
        """
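
As with the scheduler provider, OptimizerProvider only requires a callable, so a plain function suffices. A minimal sketch (the hyperparameter values are illustrative only):

from torch.optim import AdamW


def adamw_provider(context: InitializeOptimizerStageContext) -> AdamW:
    # Optimize only parameters that require gradients (useful under PEFT,
    # where most of the base model is frozen).
    params = [p for p in context.model.parameters() if p.requires_grad]
    return AdamW(params, lr=3e-4, weight_decay=0.01)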

d9d/loop/control/task.py
ADDED
@@ -0,0 +1,304 @@
import abc
import dataclasses
import typing
from collections.abc import Mapping
from typing import Any, Protocol

import torch
from torch.distributed.checkpoint.stateful import Stateful

from d9d.core.dist_context import DistributedContext
from d9d.core.types import PyTree, ScalarTree
from d9d.pipelining.api import PipelineShardingSpec

if typing.TYPE_CHECKING:
    from d9d.internals.pipeline_state import PipelineState
    from d9d.loop.component import Stepper
    from d9d.metric import Metric


TBatch = typing.TypeVar("TBatch", bound=PyTree)


@dataclasses.dataclass(kw_only=True)
class BuildForwardInputsContext(typing.Generic[TBatch]):
    """
    Context data to prepare inputs for the model forward pass.

    Attributes:
        batch: The raw batch data loaded from the DataLoader object.
        state: The current state of the pipeline. You can assign any data to this state object, and it will
            be accessible during this pipeline step (e.g. when computing loss).
    """

    batch: TBatch
    state: "PipelineState"


@dataclasses.dataclass(kw_only=True)
class BuildForwardInputsResult:
    """
    The result of processing the raw batch into model inputs.

    Attributes:
        inputs: A dictionary of inputs passed to the model pipeline as input data
            (to the first stage only if using pipeline parallelism).
        kwargs: A dictionary of keyword arguments passed to each pipeline stage.
        pipeline_sharding_spec: A specification defining how inputs and kwargs should be split
            into micro-batches for pipeline parallelism. If None, the framework assumes
            standard behavior where all non-scalar tensors and lists are split along dimension 0.
    """

    inputs: dict[str, torch.Tensor]
    kwargs: dict[str, Any]
    pipeline_sharding_spec: PipelineShardingSpec | None = None


@dataclasses.dataclass(kw_only=True)
class FinalizeContext:
    """Context data provided when the task is being finalized."""


class BaseTask(abc.ABC, Stateful, typing.Generic[TBatch]):
    """Abstract base class representing a unit of work (Task) in the training/inference loop."""

    @abc.abstractmethod
    def build_forward_inputs(self, ctx: BuildForwardInputsContext[TBatch]) -> BuildForwardInputsResult:
        """
        Transforms raw data loaded from the DataLoader into arguments for the model.

        Args:
            ctx: Context object.

        Returns:
            Result object.
        """

        ...

    def state_dict(self) -> dict[str, Any]:
        """
        Returns the state dictionary for checkpointing this task.

        Returns:
            A dictionary containing the task's state.
        """

        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """
        Restores the task's state from the provided dictionary.

        Args:
            state_dict: The state dictionary to load.
        """
        # do nothing by default

    def finalize(self, ctx: FinalizeContext) -> None:
        """
        Performs cleanup or final actions when the task execution finishes.

        Args:
            ctx: Context object.
        """


@dataclasses.dataclass(kw_only=True)
class ComputeLossContext:
    """
    Context data provided to calculate the loss during training.

    Attributes:
        pipeline_results: The outputs returned by the model's forward pass.
        state: The current state of the pipeline. You can assign any data to this state object, and it will
            be accessible during this pipeline step (e.g. when calculating metrics).
        stepper: Component tracking the current step.
    """

    pipeline_results: Mapping[str, torch.Tensor]
    state: "PipelineState"
    stepper: "Stepper"


@dataclasses.dataclass(kw_only=True)
class ComputeLossResult:
    """
    The result of the loss computation.

    Attributes:
        loss: The scalar tensor representing the loss to be backpropagated.
        loss_weight: The weight to apply to the loss (for synchronizing gradients using a weighted
            mean). None means a weight of 1.0.
    """

    loss: torch.Tensor
    loss_weight: torch.Tensor | None


@dataclasses.dataclass(kw_only=True)
class CreateMetricsContext:
    """Context data provided to initialize metrics."""


@dataclasses.dataclass(kw_only=True)
class CreateMetricsResult:
    """
    Result of metric initialization.

    Attributes:
        metrics: A dictionary mapping metric names to Metric instances.
    """

    metrics: dict[str, "Metric"]


@dataclasses.dataclass(kw_only=True)
class UpdateMetricsContext:
    """
    Context data provided to update metrics after a step.

    Attributes:
        state: The current state of the pipeline.
        metrics: The dictionary of metrics to be updated.
    """

    state: "PipelineState"
    metrics: Mapping[str, "Metric"]


class TrainTask(BaseTask, abc.ABC, typing.Generic[TBatch]):
    """Abstract base class for defining training-specific logic."""

    @abc.abstractmethod
    def compute_loss(self, ctx: ComputeLossContext) -> ComputeLossResult:
        """
        Calculates the loss based on model outputs.

        Args:
            ctx: Context object.

        Returns:
            Result object.
        """

        ...

    def create_metrics(self, ctx: CreateMetricsContext) -> CreateMetricsResult:
        """
        Initializes metrics to be tracked during training.

        Args:
            ctx: Context object.

        Returns:
            Result object.
        """

        return CreateMetricsResult(metrics={})

    def update_metrics(self, ctx: UpdateMetricsContext):
        """
        Updates the state of the metrics at the end of a training step.

        Args:
            ctx: Context object.
        """

    def dump_hparams(self) -> ScalarTree:
        """
        Exports hyperparameters associated with this task for logging.

        Returns:
            A dictionary of hyperparameter names and values.
        """

        return {}


@dataclasses.dataclass(kw_only=True)
class TrainTaskProviderContext:
    """
    Context data provided to the factory creating a TrainTask.

    Attributes:
        dist_context: Information about the distributed environment.
    """

    dist_context: DistributedContext


@typing.runtime_checkable
class TrainTaskProvider(Protocol):
    """Protocol for a callable that creates a TrainTask instance."""

    def __call__(self, ctx: TrainTaskProviderContext) -> TrainTask:
        """
        Creates and returns a new TrainTask.

        Args:
            ctx: Context object.

        Returns:
            An instantiated TrainTask.
        """

        ...


@dataclasses.dataclass(kw_only=True)
class ProcessOutputsContext:
    """
    Context data provided to process outputs during inference.

    Attributes:
        outputs: The outputs returned by the model's forward pass.
        state: The current state of the pipeline.
    """

    outputs: dict[str, torch.Tensor]
    state: "PipelineState"


class InferenceTask(BaseTask, abc.ABC, typing.Generic[TBatch]):
    """Abstract base class for defining inference-specific logic."""

    @abc.abstractmethod
    def process_outputs(self, ctx: ProcessOutputsContext):
        """
        Processes the model outputs (e.g. saving to disk, decoding tokens).

        Args:
            ctx: Context containing the model outputs and pipeline state.
        """

        ...


@dataclasses.dataclass(kw_only=True)
class InferenceTaskProviderContext:
    """
    Context data provided to the factory creating an InferenceTask.

    Attributes:
        dist_context: Information about the distributed environment.
    """

    dist_context: DistributedContext


@typing.runtime_checkable
class InferenceTaskProvider(Protocol):
    """Protocol for a callable that creates an InferenceTask instance."""

    def __call__(self, ctx: InferenceTaskProviderContext) -> InferenceTask:
        """
        Creates and returns a new InferenceTask.

        Args:
            ctx: Context providing distributed environment information.

        Returns:
            An instantiated InferenceTask.
        """

        ...
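
To tie the pieces together, here is a minimal sketch of a TrainTask for next-token prediction. The batch layout ("input_ids"/"labels"), the "logits" output key, and item-style assignment on PipelineState are all assumptions for illustration; the real state interface lives in d9d/internals/pipeline_state and is not shown in this diff.

import torch
import torch.nn.functional as F


class CausalLMTask(TrainTask[dict[str, torch.Tensor]]):
    def build_forward_inputs(
        self, ctx: BuildForwardInputsContext[dict[str, torch.Tensor]]
    ) -> BuildForwardInputsResult:
        # Stash the labels on the pipeline state so compute_loss can read them
        # later in the same pipeline step.
        ctx.state["labels"] = ctx.batch["labels"]  # assumed state interface
        return BuildForwardInputsResult(
            inputs={"input_ids": ctx.batch["input_ids"]},
            kwargs={},
        )

    def compute_loss(self, ctx: ComputeLossContext) -> ComputeLossResult:
        logits = ctx.pipeline_results["logits"]
        labels = ctx.state["labels"]
        # cross_entropy ignores label -100 by default; weight the loss by the
        # number of real tokens so gradient sync averages per token, not per rank.
        loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten())
        weight = (labels != -100).sum()
        return ComputeLossResult(loss=loss, loss_weight=weight)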