d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/loop/control/lr_scheduler_provider.py
@@ -0,0 +1,47 @@
+ import abc
+ import dataclasses
+ import typing
+ from typing import Protocol
+
+ from torch.optim import Optimizer
+
+ from d9d.core.dist_context import DistributedContext
+ from d9d.core.protocol import LRSchedulerProtocol
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class InitializeLRSchedulerContext:
+     """
+     Context data required to initialize an LR scheduler.
+
+     Attributes:
+         dist_context: The distributed context.
+         total_steps: The total number of training steps.
+         optimizer: The optimizer instance that the scheduler will control.
+     """
+
+     dist_context: DistributedContext
+     total_steps: int
+     optimizer: Optimizer
+
+
+ @typing.runtime_checkable
+ class LRSchedulerProvider(Protocol):
+     """
+     Protocol for defining how learning rate schedulers are created.
+     """
+
+     @abc.abstractmethod
+     def __call__(
+         self,
+         context: InitializeLRSchedulerContext
+     ) -> LRSchedulerProtocol:
+         """
+         Initializes the LR scheduler for a specific model pipeline stage.
+
+         Args:
+             context: Context for this operation.
+
+         Returns:
+             The instantiated LR scheduler adhering to the protocol.
+         """
d9d/loop/control/model_provider.py
@@ -0,0 +1,162 @@
+ import abc
+ import dataclasses
+ from typing import Generic, TypeVar
+
+ from torch import nn
+
+ from d9d.core.dist_context import DistributedContext
+ from d9d.core.types import ScalarTree
+ from d9d.model_state.mapper import ModelStateMapper
+ from d9d.pipelining.api import PipelineStageInfo
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class InitializeModelStageContext:
+     """
+     Context data required for initializing a specific model pipeline stage.
+
+     Attributes:
+         dist_context: The distributed execution context.
+         stage: Metadata describing the current pipeline stage being initialized.
+     """
+
+     dist_context: DistributedContext
+     stage: PipelineStageInfo
+
+
+ TModel = TypeVar("TModel", bound=nn.Module)
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class InitializeModelStageResult(Generic[TModel]):
+     """
+     The result of initializing a model stage.
+
+     Attributes:
+         model: The PyTorch module.
+         state_mapper: The mapper defining how to load weights into this module.
+     """
+
+     model: TModel
+     state_mapper: ModelStateMapper
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class ParallelizeModelStageContext(Generic[TModel]):
+     """
+     Context data required for horizontally parallelizing a model stage.
+
+     Attributes:
+         dist_context: The distributed execution context.
+         stage: Metadata describing the current pipeline stage.
+         model: The PyTorch module to be parallelized.
+     """
+
+     dist_context: DistributedContext
+     stage: PipelineStageInfo
+     model: TModel
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class PrepareExportModelStageContext(Generic[TModel]):
+     """
+     Context data required for preparing a model stage for export.
+
+     Attributes:
+         dist_context: The distributed execution context.
+         model: The PyTorch module to be exported.
+     """
+
+     dist_context: DistributedContext
+     model: TModel
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class PrepareExportModelStageResult:
+     """
+     The result of preparing a model stage for export.
+
+     Attributes:
+         state_mapper: The mapper defining how model parameters map to disk storage.
+     """
+
+     state_mapper: ModelStateMapper
+
+
+ class ModelProvider(abc.ABC, Generic[TModel]):
+     """
+     Abstract interface for defining the lifecycle of a distributed model.
+
+     This provider handles initialization, parallelization (sharding, replication, etc.), and export preparation
+     for models within the d9d framework.
+     """
+
+     @abc.abstractmethod
+     def initialize_model_stage(
+         self,
+         context: InitializeModelStageContext
+     ) -> InitializeModelStageResult[TModel]:
+         """
+         Initializes the model architecture for a specific pipeline stage.
+
+         This method is responsible for constructing the `nn.Module` for the requested stage.
+
+         Construction occurs within a meta-device context; therefore, weights
+         should not be loaded directly here. Instead, a `ModelStateMapper` must be returned
+         to define how weights from a checkpoint map to the newly created module parameters.
+
+         This allows for architecture modifications, such as injecting LoRA adapters,
+         provided that the returned mapper reflects the new structure.
+
+         Args:
+             context: Context for this operation.
+
+         Returns:
+             Result of this operation.
+         """
+
+         ...
+
+     @abc.abstractmethod
+     def parallelize_model_stage(
+         self,
+         context: ParallelizeModelStageContext[TModel]
+     ):
+         """
+         Converts the model parameters into distributed tensors (DTensors).
+
+         Implementations should modify the model in-place. This involves converting
+         standard parameters into DTensors by replicating or sharding them according
+         to the desired parallelism strategies.
+
+         Args:
+             context: Context for this operation.
+         """
+
+     @abc.abstractmethod
+     def prepare_export_model_stage(
+         self,
+         context: PrepareExportModelStageContext[TModel]
+     ) -> PrepareExportModelStageResult:
+         """
+         Prepares the state mapper required for saving the model to disk.
+
+         This method defines how the current in-memory model structure maps back to the
+         serialized checkpoint format.
+
+         Args:
+             context: Context for this operation.
+
+         Returns:
+             Result of this operation.
+         """
+
+     def dump_hparams(self) -> ScalarTree:
+         """
+         Exports hyperparameters associated with this model for logging.
+
+         Returns:
+             A dictionary of hyperparameter names and values.
+         """
+
+         return {}
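
A sketch of a concrete ModelProvider for a toy single-stage model (hypothetical user code; the re-exports from d9d.loop.control are assumed, and IdentityMapper stands in for whatever identity mapper d9d/model_state/mapper/leaf/identity.py actually exports, which this view does not show):

from torch import nn

from d9d.loop.control import (  # assumed re-exports
    InitializeModelStageContext,
    InitializeModelStageResult,
    ModelProvider,
    ParallelizeModelStageContext,
    PrepareExportModelStageContext,
    PrepareExportModelStageResult,
)
from d9d.model_state.mapper.leaf import IdentityMapper  # assumed name


class TinyMlpProvider(ModelProvider[nn.Sequential]):
    def initialize_model_stage(
        self, context: InitializeModelStageContext
    ) -> InitializeModelStageResult[nn.Sequential]:
        # Runs under a meta-device context, so only the architecture is built
        # here; real weights arrive later via the returned state mapper.
        model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))
        return InitializeModelStageResult(model=model, state_mapper=IdentityMapper())

    def parallelize_model_stage(
        self, context: ParallelizeModelStageContext[nn.Sequential]
    ) -> None:
        pass  # single-device sketch: nothing to shard or replicate

    def prepare_export_model_stage(
        self, context: PrepareExportModelStageContext[nn.Sequential]
    ) -> PrepareExportModelStageResult:
        return PrepareExportModelStageResult(state_mapper=IdentityMapper())
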
d9d/loop/control/optimizer_provider.py
@@ -0,0 +1,45 @@
+ import abc
+ import dataclasses
+ import typing
+ from typing import Protocol
+
+ from torch import nn
+ from torch.optim import Optimizer
+
+ from d9d.core.dist_context import DistributedContext
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class InitializeOptimizerStageContext:
+     """
+     Context data required to initialize an optimizer.
+
+     Attributes:
+         dist_context: The distributed context.
+         model: The model instance whose parameters will be optimized.
+     """
+
+     dist_context: DistributedContext
+     model: nn.Module
+
+
+ @typing.runtime_checkable
+ class OptimizerProvider(Protocol):
+     """
+     Protocol for defining how optimizers are created for model pipeline stages.
+     """
+
+     @abc.abstractmethod
+     def __call__(
+         self,
+         context: InitializeOptimizerStageContext
+     ) -> Optimizer:
+         """
+         Initializes the optimizer for a specific training stage.
+
+         Args:
+             context: Context for this operation.
+
+         Returns:
+             The instantiated PyTorch optimizer.
+         """
d9d/loop/control/task.py
@@ -0,0 +1,304 @@
+ import abc
+ import dataclasses
+ import typing
+ from collections.abc import Mapping
+ from typing import Any, Protocol
+
+ import torch
+ from torch.distributed.checkpoint.stateful import Stateful
+
+ from d9d.core.dist_context import DistributedContext
+ from d9d.core.types import PyTree, ScalarTree
+ from d9d.pipelining.api import PipelineShardingSpec
+
+ if typing.TYPE_CHECKING:
+     from d9d.internals.pipeline_state import PipelineState
+     from d9d.loop.component import Stepper
+     from d9d.metric import Metric
+
+
+ TBatch = typing.TypeVar("TBatch", bound=PyTree)
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class BuildForwardInputsContext(typing.Generic[TBatch]):
+     """
+     Context data to prepare inputs for the model forward pass.
+
+     Attributes:
+         batch: The raw batch data loaded from the DataLoader object.
+         state: The current state of the pipeline. You can assign any data to this state object, and it will be
+             accessible during this pipeline step (e.g. when computing loss).
+     """
+
+     batch: TBatch
+     state: "PipelineState"
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class BuildForwardInputsResult:
+     """
+     The result of processing the raw batch into model inputs.
+
+     Attributes:
+         inputs: A dictionary of inputs passed to the model pipeline as input data
+             (fed to the first stage only when using pipeline parallelism).
+         kwargs: A dictionary of keyword arguments passed to each pipeline stage.
+         pipeline_sharding_spec: A specification defining how inputs and kwargs should be split
+             into micro-batches for pipeline parallelism. If None, the framework assumes the
+             standard behavior, where all non-scalar tensors and lists are split along dimension 0.
+     """
+
+     inputs: dict[str, torch.Tensor]
+     kwargs: dict[str, Any]
+     pipeline_sharding_spec: PipelineShardingSpec | None = None
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class FinalizeContext:
+     """Context data provided when the task is being finalized."""
+
+
+ class BaseTask(abc.ABC, Stateful, typing.Generic[TBatch]):
+     """Abstract base class representing a unit of work (Task) in the training/inference loop."""
+
+     @abc.abstractmethod
+     def build_forward_inputs(self, ctx: BuildForwardInputsContext[TBatch]) -> BuildForwardInputsResult:
+         """
+         Transforms raw data loaded from the DataLoader into arguments for the model.
+
+         Args:
+             ctx: Context object.
+
+         Returns:
+             Result object.
+         """
+
+         ...
+
+     def state_dict(self) -> dict[str, Any]:
+         """
+         Returns the state dictionary for checkpointing this task.
+
+         Returns:
+             A dictionary containing the task's state.
+         """
+
+         return {}
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         """
+         Restores the task's state from the provided dictionary.
+
+         Args:
+             state_dict: The state dictionary to load.
+         """
+         # do nothing by default
+
+     def finalize(self, ctx: FinalizeContext) -> None:
+         """
+         Performs cleanup or final actions when the task execution finishes.
+
+         Args:
+             ctx: Context object.
+         """
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class ComputeLossContext:
+     """
+     Context data provided to calculate the loss during training.
+
+     Attributes:
+         pipeline_results: The outputs returned by the model's forward pass.
+         state: The current state of the pipeline. You can assign any data to this state object, and it will be
+             accessible during this pipeline step (e.g. when calculating metrics).
+         stepper: Component tracking the current step.
+     """
+
+     pipeline_results: Mapping[str, torch.Tensor]
+     state: "PipelineState"
+     stepper: "Stepper"
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class ComputeLossResult:
+     """
+     The result of the loss computation.
+
+     Attributes:
+         loss: The scalar tensor representing the loss to be backpropagated.
+         loss_weight: The weight to apply to the loss (for synchronizing gradients using a weighted mean).
+             None is treated as 1.0.
+     """
+
+     loss: torch.Tensor
+     loss_weight: torch.Tensor | None
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class CreateMetricsContext:
+     """Context data provided to initialize metrics."""
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class CreateMetricsResult:
+     """
+     Result of metric initialization.
+
+     Attributes:
+         metrics: A dictionary mapping metric names to Metric instances.
+     """
+
+     metrics: dict[str, "Metric"]
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class UpdateMetricsContext:
+     """
+     Context data provided to update metrics after a step.
+
+     Attributes:
+         state: The current state of the pipeline.
+         metrics: The dictionary of metrics to be updated.
+     """
+
+     state: "PipelineState"
+     metrics: Mapping[str, "Metric"]
+
+
+ class TrainTask(BaseTask, abc.ABC, typing.Generic[TBatch]):
+     """Abstract base class for defining training-specific logic."""
+
+     @abc.abstractmethod
+     def compute_loss(self, ctx: ComputeLossContext) -> ComputeLossResult:
+         """
+         Calculates the loss based on model outputs.
+
+         Args:
+             ctx: Context object.
+
+         Returns:
+             Result object.
+         """
+
+         ...
+
+     def create_metrics(self, ctx: CreateMetricsContext) -> CreateMetricsResult:
+         """
+         Initializes metrics to be tracked during training.
+
+         Args:
+             ctx: Context object.
+
+         Returns:
+             Result object.
+         """
+
+         return CreateMetricsResult(metrics={})
+
+     def update_metrics(self, ctx: UpdateMetricsContext):
+         """
+         Updates the state of the metrics at the end of a training step.
+
+         Args:
+             ctx: Context object.
+         """
+
+     def dump_hparams(self) -> ScalarTree:
+         """
+         Exports hyperparameters associated with this task for logging.
+
+         Returns:
+             A dictionary of hyperparameter names and values.
+         """
+
+         return {}
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class TrainTaskProviderContext:
+     """
+     Context data provided to the factory creating a TrainTask.
+
+     Attributes:
+         dist_context: Information about the distributed environment.
+     """
+
+     dist_context: DistributedContext
+
+
+ @typing.runtime_checkable
+ class TrainTaskProvider(Protocol):
+     """Protocol for a callable that creates a TrainTask instance."""
+
+     def __call__(self, ctx: TrainTaskProviderContext) -> TrainTask:
+         """
+         Creates and returns a new TrainTask.
+
+         Args:
+             ctx: Context object.
+
+         Returns:
+             An instantiated TrainTask.
+         """
+
+         ...
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class ProcessOutputsContext:
+     """
+     Context data provided to process outputs during inference.
+
+     Attributes:
+         outputs: The outputs returned by the model's forward pass.
+         state: The current state of the pipeline.
+     """
+
+     outputs: dict[str, torch.Tensor]
+     state: "PipelineState"
+
+
+ class InferenceTask(BaseTask, abc.ABC, typing.Generic[TBatch]):
+     """Abstract base class for defining inference-specific logic."""
+
+     @abc.abstractmethod
+     def process_outputs(self, ctx: ProcessOutputsContext):
+         """
+         Processes the model outputs (e.g. saving to disk, decoding tokens).
+
+         Args:
+             ctx: Context containing the model outputs and pipeline state.
+         """
+
+         ...
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class InferenceTaskProviderContext:
+     """
+     Context data provided to the factory creating an InferenceTask.
+
+     Attributes:
+         dist_context: Information about the distributed environment.
+     """
+
+     dist_context: DistributedContext
+
+
+ @typing.runtime_checkable
+ class InferenceTaskProvider(Protocol):
+     """Protocol for a callable that creates an InferenceTask instance."""
+
+     def __call__(self, ctx: InferenceTaskProviderContext) -> InferenceTask:
+         """
+         Creates and returns a new InferenceTask.
+
+         Args:
+             ctx: Context providing distributed environment information.
+
+         Returns:
+             An instantiated InferenceTask.
+         """
+         ...
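
A sketch of a minimal TrainTask for causal language modelling (hypothetical user code; it assumes the batch is a dict of tensors with a -100 label-padding convention, that the model's forward pass returns a "loss" entry, and that PipelineState permits arbitrary attribute assignment, as its docstring suggests; the re-exports from d9d.loop.control are also assumed):

import torch

from d9d.loop.control import (  # assumed re-exports
    BuildForwardInputsContext,
    BuildForwardInputsResult,
    ComputeLossContext,
    ComputeLossResult,
    TrainTask,
)


class CausalLmTask(TrainTask[dict[str, torch.Tensor]]):
    def build_forward_inputs(
        self, ctx: BuildForwardInputsContext[dict[str, torch.Tensor]]
    ) -> BuildForwardInputsResult:
        # Record the number of real (non-padding) label tokens on the per-step
        # state so compute_loss can weight the loss accordingly.
        ctx.state.n_tokens = (ctx.batch["labels"] != -100).sum()
        return BuildForwardInputsResult(
            inputs={"input_ids": ctx.batch["input_ids"]},
            kwargs={"labels": ctx.batch["labels"]},
        )

    def compute_loss(self, ctx: ComputeLossContext) -> ComputeLossResult:
        return ComputeLossResult(
            loss=ctx.pipeline_results["loss"],
            loss_weight=ctx.state.n_tokens,
        )

A matching TrainTaskProvider can then be as simple as lambda ctx: CausalLmTask().
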
d9d/loop/run/__init__.py
@@ -0,0 +1,6 @@
+ from .train import Trainer, TrainingConfigurator
+
+ __all__ = [
+     "Trainer",
+     "TrainingConfigurator"
+ ]