d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/loop/component/stepper.py
ADDED
@@ -0,0 +1,52 @@
+from typing import Any
+
+from torch.distributed.checkpoint.stateful import Stateful
+
+from d9d.loop.config import StepActionPeriod, StepActionSpecial
+
+
+class Stepper(Stateful):
+    def __init__(self, initial_step: int, total_steps: int):
+        self._current_step = initial_step
+        self._total_steps = total_steps
+
+    def step(self):
+        self._current_step += 1
+
+    @property
+    def current_step(self) -> int:
+        return self._current_step
+
+    @property
+    def total_steps(self) -> int:
+        return self._total_steps
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "current_step": self._current_step,
+            "total_steps": self._total_steps
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        if state_dict["total_steps"] != self._total_steps:
+            raise ValueError(f'Step count differs: saved {state_dict["total_steps"]}, '
+                             f'current {self._total_steps}. Perhaps project configuration changed?')
+
+        self._current_step = state_dict["current_step"]
+
+    def should_do_action(self, action: StepActionPeriod, enable_on_last_step_if_periodic: bool = False) -> bool:
+        match action:
+            case StepActionSpecial.disable:
+                return False
+            case StepActionSpecial.last_step:
+                return self._current_step == self._total_steps
+            case int():
+                if action <= 0:
+                    raise ValueError()
+
+                will_do_periodic = self._current_step % action == 0
+                will_do_last = enable_on_last_step_if_periodic and self._current_step == self._total_steps
+
+                return will_do_periodic or will_do_last
+            case _:
+                raise ValueError("Invalid step configuration")
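A minimal usage sketch (not part of the wheel) of how `Stepper.should_do_action` interprets `StepActionPeriod` values; the step counts below are invented for illustration.

```python
from d9d.loop.component.stepper import Stepper
from d9d.loop.config import StepActionSpecial

stepper = Stepper(initial_step=0, total_steps=100)
for _ in range(100):
    stepper.step()  # advance to steps 1..100

    # Integer period: fires on steps 10, 20, ..., 100.
    if stepper.should_do_action(10):
        pass  # e.g. flush metrics or save a checkpoint

    # Special values: never fire, or fire only on the final step.
    never = stepper.should_do_action(StepActionSpecial.disable)      # always False
    at_end = stepper.should_do_action(StepActionSpecial.last_step)   # True only at step 100
```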
d9d/loop/component/timeout_manager.py
ADDED
@@ -0,0 +1,54 @@
+from enum import StrEnum
+
+from d9d.core.dist_context import DistributedContext
+from d9d.loop.config.config import TimeoutConfig
+
+
+class TimeoutState(StrEnum):
+    none = "none"
+    set_initial = "set_initial"
+    set_regular = "set_regular"
+
+
+class TimeoutManager:
+    """
+    Manages the dynamic adjustment of distributed timeouts during the job loop.
+
+    This manager handles the transition from initialization timeouts (which may need
+    to be longer due to JIT compilation, caching, or startup overhead) to regular
+    step execution timeouts.
+    """
+
+    def __init__(
+        self,
+        dist_context: DistributedContext,
+        config: TimeoutConfig
+    ):
+        """
+        Constructs the timeout manager.
+
+        Args:
+            dist_context: The distributed context where timeouts are applied.
+            config: Configuration containing initialization and step timeout values.
+        """
+
+        self._dist_context = dist_context
+        self._config = config
+        self._state = TimeoutState.none
+
+    def step(self):
+        """
+        Updates the distributed backend timeout based on the current phase.
+        """
+
+        match self._state:
+            case TimeoutState.none:
+                self._dist_context.set_timeout(self._config.init_timeout)
+                self._state = TimeoutState.set_initial
+            case TimeoutState.set_initial:
+                self._dist_context.set_timeout(self._config.step_timeout)
+                self._state = TimeoutState.set_regular
+            case TimeoutState.set_regular:
+                pass  # do nothing
+            case _:
+                raise ValueError("Unknown timeout state")
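A small sketch (not from the diff) of the intended call pattern: the first `step()` applies the long initialization timeout, the second switches to the regular step timeout, and later calls are no-ops. Construction of the `DistributedContext` is outside this file and is assumed here.

```python
from d9d.loop.component.timeout_manager import TimeoutManager
from d9d.loop.config import TimeoutConfig

config = TimeoutConfig(init_timeout=10000, step_timeout=100)
manager = TimeoutManager(dist_context=dist_context, config=config)  # dist_context assumed to exist

manager.step()  # applies init_timeout (covers JIT compilation / startup overhead)
manager.step()  # switches to step_timeout for regular training steps
manager.step()  # no-op from here on
```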
d9d/loop/component/train_task_operator.py
ADDED
@@ -0,0 +1,152 @@
+import dataclasses
+
+import torch
+
+from d9d.core.dist_context import DistributedContext
+from d9d.core.types import PyTree
+from d9d.internals.pipeline_state import PipelineStateHandler
+from d9d.loop.control import BuildForwardInputsContext, BuildForwardInputsResult, TrainTask, UpdateMetricsContext
+from d9d.metric.impl import ComposeMetric
+from d9d.pipelining.factory.factory import PipelineScheduleInfo
+
+from .loss_computer import STATE_LOSS, STATE_LOSS_WEIGHT, LossComputer
+from .model_stage_factory import TrackedModules
+
+
+@dataclasses.dataclass(kw_only=True)
+class ForwardResult:
+    """
+    Encapsulates the scalar results of a forward pass step.
+
+    Attributes:
+        loss: The computed loss tensor.
+        loss_weight: The weight associated with this loss (usually batch size).
+    """
+
+    loss: torch.Tensor
+    loss_weight: torch.Tensor
+
+
+class TrainTaskOperator:
+    """
+    Orchestrates the execution of the forward and backward passes for a specific training task.
+
+    This class abstracts the difference between standard execution
+    and pipeline-parallel execution. It manages input construction, schedule execution,
+    loss computation, and metric updates within the lifecycle of a single step.
+    """
+
+    def __init__(
+        self,
+        dist_context: DistributedContext,
+        task: TrainTask,
+        pp_schedule: PipelineScheduleInfo | None,
+        tracked_modules: TrackedModules,
+        loss_computer: LossComputer,
+        pipeline_state: PipelineStateHandler,
+        metrics: ComposeMetric
+    ):
+        """
+        Constructs the TrainTaskOperator.
+
+        Args:
+            dist_context: The distributed context.
+            task: The user-defined training task logic.
+            pp_schedule: Information about the pipeline schedule.
+            tracked_modules: The model modules being trained.
+            loss_computer: Component responsible for calculating loss from outputs.
+            pipeline_state: Handler for transient state storage during the step.
+            metrics: Metric collection to update after the pass.
+        """
+
+        self._dist_context = dist_context
+        self._task = task
+        self._pp_schedule = pp_schedule
+        self._tracked_modules = tracked_modules
+        self._loss_computer = loss_computer
+        self._pipeline_state = pipeline_state
+        self._metrics = metrics
+
+    def _forward_backward_pipelining(self, model_inputs: BuildForwardInputsResult):
+        if self._pp_schedule is None:
+            raise ValueError("Cannot run pipelined pass if pipelining is disabled")
+
+        self._pp_schedule.schedule.configure_buffers(
+            inputs=model_inputs.inputs,
+            kwargs=model_inputs.kwargs,
+            sharding_spec=model_inputs.pipeline_sharding_spec
+        )
+        self._pp_schedule.schedule.step(
+            inputs=model_inputs.inputs,
+            kwargs=model_inputs.kwargs
+        )
+
+    def _forward_backward_regular(self, model_inputs: BuildForwardInputsResult):
+        pipeline_outputs = self._tracked_modules(
+            **model_inputs.inputs,
+            **model_inputs.kwargs
+        )
+        loss = self._loss_computer.compute_loss_mul_weight(
+            pipeline_outputs=pipeline_outputs,
+            microbatch_idx=None
+        )
+        # free to avoid bwd peaking memory
+        del pipeline_outputs
+        loss.backward()
+
+    def forward_backward(self, batch: PyTree) -> ForwardResult | None:
+        """
+        Executes the forward and backward passes for a single batch.
+
+        This method handles:
+
+        1. Context preparation and input building via the `TrainTask`.
+        2. Execution via Pipeline Parallel schedule or standard Forward/Backward.
+        3. Metric updates based on the results.
+        4. Reliable cleanup of the pipeline state.
+
+        Args:
+            batch: The input batch data.
+
+        Returns:
+            A `ForwardResult` containing the loss and weight if this rank is responsible
+            for loss calculation (e.g., the last pipeline stage or in standard DP).
+            Returns `None` if this rank is an intermediate pipeline stage that does
+            not compute loss.
+        """
+
+        try:
+            # Do forward and backward pass
+            model_inputs = self._task.build_forward_inputs(
+                BuildForwardInputsContext(
+                    batch=batch,
+                    state=self._pipeline_state.global_state()
+                )
+            )
+
+            if self._dist_context.mesh_params.has_pipeline_parallel:
+                self._forward_backward_pipelining(model_inputs)
+            else:
+                self._forward_backward_regular(model_inputs)
+
+            # Update metrics if possible
+
+            pipeline_state = self._pipeline_state.global_state()
+
+            if (
+                self._dist_context.mesh_params.has_pipeline_parallel and
+                self._pp_schedule is not None and
+                not self._pp_schedule.has_last_stage
+            ):
+                return None
+
+            self._task.update_metrics(UpdateMetricsContext(
+                state=pipeline_state,
+                metrics=self._metrics.children
+            ))
+            return ForwardResult(
+                loss=pipeline_state[STATE_LOSS],
+                loss_weight=pipeline_state[STATE_LOSS_WEIGHT]
+            )
+        finally:
+            self._pipeline_state.reset()
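For orientation only (not part of the wheel), a hedged sketch of how a surrounding training loop might drive `forward_backward`; the operator, its collaborators, and the data loader are assumed to be constructed elsewhere (for example in `d9d/loop/run/train.py`).

```python
# Hypothetical driver loop; `operator` is an already-constructed TrainTaskOperator and
# `data_loader` an iterable of batches -- both are assumptions for this sketch.
for batch in data_loader:
    result = operator.forward_backward(batch)
    if result is None:
        continue  # intermediate pipeline stage: this rank does not own the loss

    # The loss appears to be pre-multiplied by its weight (compute_loss_mul_weight),
    # so dividing by the weight gives a per-sample value for reporting.
    print(result.loss.item() / result.loss_weight.item())
```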
d9d/loop/config/__init__.py
ADDED
@@ -0,0 +1,36 @@
+from .config import (
+    BatchingConfig,
+    CheckpointingConfig,
+    DataLoadingConfig,
+    DeterminismConfig,
+    GarbageCollectionConfig,
+    GradientClippingConfig,
+    GradientManagerConfig,
+    InferenceConfig,
+    JobLoggerConfig,
+    ModelStageFactoryConfig,
+    PipeliningConfig,
+    ProfilingConfig,
+    TimeoutConfig,
+    TrainerConfig,
+)
+from .types import StepActionPeriod, StepActionSpecial
+
+__all__ = [
+    "BatchingConfig",
+    "CheckpointingConfig",
+    "DataLoadingConfig",
+    "DeterminismConfig",
+    "GarbageCollectionConfig",
+    "GradientClippingConfig",
+    "GradientManagerConfig",
+    "InferenceConfig",
+    "JobLoggerConfig",
+    "ModelStageFactoryConfig",
+    "PipeliningConfig",
+    "ProfilingConfig",
+    "StepActionPeriod",
+    "StepActionSpecial",
+    "TimeoutConfig",
+    "TrainerConfig"
+]
d9d/loop/config/config.py
ADDED
@@ -0,0 +1,225 @@
+from pathlib import Path
+
+from pydantic import BaseModel
+
+from d9d.pipelining.factory import AnyPipelineScheduleConfig
+from d9d.tracker import AnyTrackerConfig, RunConfig
+
+from .types import StepActionPeriod
+
+
+class BatchingConfig(BaseModel):
+    """
+    Configuration for batch sizing logic.
+
+    Attributes:
+        global_batch_size: The total effective batch size across all distributed
+            replicas and gradient accumulation steps.
+        microbatch_size: The distinct batch size fed into the model during a single
+            forward pass on a single device.
+    """
+    global_batch_size: int
+    microbatch_size: int
+
+
+class DeterminismConfig(BaseModel):
+    """
+    Configuration for reproducibility and random number generation.
+
+    Attributes:
+        base_seed: The base integer seed used to initialize random number
+            generators (Python, NumPy, PyTorch) across all ranks.
+    """
+    base_seed: int
+
+
+class PipeliningConfig(BaseModel):
+    """
+    Configuration for pipeline parallelism orchestration.
+
+    Attributes:
+        schedule: The specific scheduling strategy configuration used to manage pipeline execution.
+    """
+    schedule: AnyPipelineScheduleConfig
+
+
+class GarbageCollectionConfig(BaseModel):
+    """
+    Configuration for manual Python garbage collection control.
+
+    Attributes:
+        period_steps: How frequently to manually trigger the Python garbage collector.
+    """
+    period_steps: StepActionPeriod
+
+
+class DataLoadingConfig(BaseModel):
+    """
+    Configuration for PyTorch DataLoaders.
+
+    Attributes:
+        num_workers: The number of subprocesses to use for data loading.
+        pin_memory: Whether to copy tensors into CUDA pinned memory before returning them.
+        persistent_workers: If True, the data loader will not shutdown the worker processes
+            after a dataset has been consumed once.
+    """
+    num_workers: int
+    pin_memory: bool
+    persistent_workers: bool
+
+
+class CheckpointingConfig(BaseModel):
+    """
+    Configuration for saving model snapshots.
+
+    Attributes:
+        save_dir: The root directory where checkpoints will be stored.
+        period_steps: How frequently to save a checkpoint.
+        num_to_keep: The maximum number of recent checkpoints to retain. If None,
+            all checkpoints are kept.
+    """
+    save_dir: Path
+    period_steps: StepActionPeriod
+    num_to_keep: int | None
+
+
+class ModelStageFactoryConfig(BaseModel):
+    """
+    Configuration for initializing model weights.
+
+    Attributes:
+        source_checkpoint: Path to an initial checkpoint to load into the model
+            before training starts. If None, random initialization is used.
+        checkpoint_only_trainable_parameters: If True, only parameters with
+            requires_grad=True will be saved in checkpoints. Useful for PEFT/LoRA.
+    """
+    source_checkpoint: Path | None
+    checkpoint_only_trainable_parameters: bool
+
+
+class GradientClippingConfig(BaseModel):
+    """
+    Configuration for gradient norm clipping.
+
+    Attributes:
+        max_norm: The maximum norm value for gradient clipping. If None,
+            no clipping is performed.
+        log_total_steps: Frequency at which to log the total gradient norm.
+    """
+    max_norm: float | None
+    log_total_steps: StepActionPeriod
+
+
+class ProfilingConfig(BaseModel):
+    """
+    Configuration for the PyTorch Profiler.
+
+    Attributes:
+        enabled: Whether to enable the profiler.
+        traces_dir: Directory where trace files will be saved.
+        period_steps: Total length of a profiling cycle (wait + warmup + active).
+        warmup_steps: Number of steps to ignore before recording to allow for warming-up.
+        active_steps: Number of steps to actively record traces.
+    """
+    enabled: bool
+
+    traces_dir: Path
+
+    period_steps: int
+    warmup_steps: int
+    active_steps: int
+
+
+class JobLoggerConfig(BaseModel):
+    """
+    Configuration for experiment tracking and logging.
+
+    Attributes:
+        period_steps: How frequently metrics are flushed to the logger.
+        tracker: Logic for the specific tracking backend (e.g., WandB, MLflow, stdout).
+    """
+    period_steps: StepActionPeriod
+    tracker: AnyTrackerConfig
+
+
+class GradientManagerConfig(BaseModel):
+    """
+    Configuration for gradient synchronization.
+
+    Attributes:
+        grad_dtype: The data type to use for storing the gradient. If None, follows the model's dtype.
+        bucket_size_mb: The size of gradient buckets in Megabytes for communication.
+    """
+    grad_dtype: str | None
+    bucket_size_mb: int
+
+
+class TimeoutConfig(BaseModel):
+    """
+    Configuration for distributed process group timeouts.
+
+    Attributes:
+        init_timeout: Timeout in seconds for initializing the process group.
+        step_timeout: Timeout in seconds for individual step communications.
+    """
+    init_timeout: int = 10000
+    step_timeout: int = 100
+
+
+class TrainerConfig(BaseModel):
+    """
+    Top-level configuration object defining a complete training job.
+
+    Attributes:
+        run: Meta-information about the run (name, ID, tags).
+        batching: Batch sizing strategy.
+        data_loading: DataLoader settings.
+        logging: Experiment tracking settings.
+        pipelining: Pipeline Parallelism schedule and settings. If None,
+            pipeline parallelism is disabled.
+        model_stage_factory: Model initialization and additional checkpointing logic.
+        determinism: Random seed settings.
+        gc: Garbage collection settings.
+        checkpointing: Checkpoint saving settings.
+        gradient_clipping: Gradient clipping settings.
+        profiling: Profiler settings.
+        gradient_manager: Gradient Synchronization Settings.
+        timeout: Distributed timeout settings.
+    """
+    run: RunConfig
+    batching: BatchingConfig
+    data_loading: DataLoadingConfig
+    logging: JobLoggerConfig
+    pipelining: PipeliningConfig | None
+    model_stage_factory: ModelStageFactoryConfig
+    determinism: DeterminismConfig
+    gc: GarbageCollectionConfig
+    checkpointing: CheckpointingConfig
+    gradient_clipping: GradientClippingConfig
+    profiling: ProfilingConfig | None
+    gradient_manager: GradientManagerConfig
+    timeout: TimeoutConfig
+
+
+class InferenceConfig(BaseModel):
+    """
+    Top-level configuration object defining an inference/evaluation job.
+
+    Attributes:
+        batching: Batch sizing strategy.
+        data_loading: DataLoader settings.
+        model_stage_factory: Model initialization logic.
+        determinism: Random seed settings.
+        gc: Garbage collection settings.
+        checkpointing: Checkpointing settings.
+        profiling: Profiler settings.
+        timeout: Distributed timeout settings.
+    """
+    batching: BatchingConfig
+    data_loading: DataLoadingConfig
+    model_stage_factory: ModelStageFactoryConfig
+    determinism: DeterminismConfig
+    gc: GarbageCollectionConfig
+    checkpointing: CheckpointingConfig
+    profiling: ProfilingConfig | None
+    timeout: TimeoutConfig
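To make the period fields above concrete, here is a hypothetical construction of a few of the smaller config models (all values invented for the example; only classes defined in this file are used).

```python
from d9d.loop.config import (
    CheckpointingConfig,
    GarbageCollectionConfig,
    StepActionSpecial,
    TimeoutConfig,
)

checkpointing = CheckpointingConfig(
    save_dir="checkpoints",   # plain string coerced to Path by pydantic
    period_steps=500,         # save every 500 steps
    num_to_keep=3,
)
gc = GarbageCollectionConfig(period_steps=StepActionSpecial.last_step)  # collect only at the end
timeout = TimeoutConfig()     # defaults: init_timeout=10000, step_timeout=100 (seconds)
```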
d9d/loop/config/types.py
ADDED
@@ -0,0 +1,24 @@
+from enum import StrEnum
+
+
+class StepActionSpecial(StrEnum):
+    """
+    Special flag values for configuring periodic actions.
+
+    Attributes:
+        last_step: Indicates the action should occur exactly once at the
+            very end of the training run.
+        disable: Indicates the action should never occur.
+    """
+    last_step = "last_step"
+    disable = "disable"
+
+
+StepActionPeriod = int | StepActionSpecial
+"""
+Union type representing a configuration for periodic events.
+
+Values:
+    int: The period in steps (frequency) at which the event occurs.
+    StepActionSpecial: A special flag indicating end-of-run execution or disabling.
+"""
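Purely for illustration, the three forms a `StepActionPeriod` value can take:

```python
from d9d.loop.config import StepActionPeriod, StepActionSpecial

every_100_steps: StepActionPeriod = 100                          # periodic
never: StepActionPeriod = StepActionSpecial.disable              # action disabled
only_at_end: StepActionPeriod = StepActionSpecial.last_step      # fires once, on the final step
```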
d9d/loop/control/__init__.py
ADDED
@@ -0,0 +1,61 @@
+from .dataset_provider import DatasetProvider, InitializeDatasetContext, InitializeDatasetResult
+from .lr_scheduler_provider import InitializeLRSchedulerContext, LRSchedulerProvider
+from .model_provider import (
+    InitializeModelStageContext,
+    InitializeModelStageResult,
+    ModelProvider,
+    ParallelizeModelStageContext,
+    PrepareExportModelStageContext,
+    PrepareExportModelStageResult,
+)
+from .optimizer_provider import InitializeOptimizerStageContext, OptimizerProvider
+from .task import (
+    BaseTask,
+    BuildForwardInputsContext,
+    BuildForwardInputsResult,
+    ComputeLossContext,
+    ComputeLossResult,
+    CreateMetricsContext,
+    CreateMetricsResult,
+    FinalizeContext,
+    InferenceTask,
+    InferenceTaskProvider,
+    InferenceTaskProviderContext,
+    ProcessOutputsContext,
+    TrainTask,
+    TrainTaskProvider,
+    TrainTaskProviderContext,
+    UpdateMetricsContext,
+)
+
+__all__ = [
+    "BaseTask",
+    "BuildForwardInputsContext",
+    "BuildForwardInputsResult",
+    "ComputeLossContext",
+    "ComputeLossResult",
+    "CreateMetricsContext",
+    "CreateMetricsResult",
+    "DatasetProvider",
+    "FinalizeContext",
+    "InferenceTask",
+    "InferenceTaskProvider",
+    "InferenceTaskProviderContext",
+    "InitializeDatasetContext",
+    "InitializeDatasetResult",
+    "InitializeLRSchedulerContext",
+    "InitializeModelStageContext",
+    "InitializeModelStageResult",
+    "InitializeOptimizerStageContext",
+    "LRSchedulerProvider",
+    "ModelProvider",
+    "OptimizerProvider",
+    "ParallelizeModelStageContext",
+    "PrepareExportModelStageContext",
+    "PrepareExportModelStageResult",
+    "ProcessOutputsContext",
+    "TrainTask",
+    "TrainTaskProvider",
+    "TrainTaskProviderContext",
+    "UpdateMetricsContext"
+]
d9d/loop/control/dataset_provider.py
ADDED
@@ -0,0 +1,58 @@
+import dataclasses
+import typing
+from typing import Protocol
+
+from torch.utils.data import Dataset
+
+from d9d.core.dist_context import DistributedContext
+from d9d.core.types import CollateFn
+
+if typing.TYPE_CHECKING:
+    from d9d.loop.component import BatchMaths
+
+
+@dataclasses.dataclass(kw_only=True)
+class InitializeDatasetContext:
+    """Context data required to initialize a dataset provider.
+
+    Attributes:
+        dist_context: The distributed context containing rank and world size information.
+        batch_maths: The batch maths component handling global batch size calculations.
+    """
+
+    dist_context: DistributedContext
+    batch_maths: "BatchMaths"
+
+
+@dataclasses.dataclass(kw_only=True)
+class InitializeDatasetResult:
+    """The result of initializing a dataset provider.
+
+    Attributes:
+        dataset: The instantiated PyTorch Dataset.
+        collator: The function used to collate individual samples into a batch.
+    """
+
+    dataset: Dataset
+    collator: CollateFn
+
+
+@typing.runtime_checkable
+class DatasetProvider(Protocol):
+    """Protocol that allows users to define how datasets are loaded and collated.
+
+    Users should subclass this to provide custom data loading logic.
+    """
+
+    def __call__(self, context: InitializeDatasetContext) -> InitializeDatasetResult:
+        """
+        Initializes the dataset components.
+
+        Note that the user must shard the dataset manually, for example using `d9d.dataset.ShardedDataset`.
+
+        Args:
+            context: Context for this operation.
+
+        Returns:
+            Result of this operation.
+        """