d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0

d9d/loop/component/job_logger.py

@@ -0,0 +1,146 @@
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

import torch
import torch.utils._pytree as pytree  # noqa: PLC2701
from torch.distributed.checkpoint.stateful import Stateful

from d9d.core.dist_context import DistributedContext
from d9d.core.types import PyTree, ScalarTree
from d9d.internals.state import load_state_dict_main_process, state_dict_main_process
from d9d.loop.config import JobLoggerConfig
from d9d.metric.impl import ComposeMetric
from d9d.tracker import BaseTracker, BaseTrackerRun, RunConfig, tracker_from_config
from d9d.tracker.provider.null import NullTrackerConfig

from .stepper import Stepper


def _flatten_pytree_for_metrics(tree: PyTree[float]) -> dict[str, float]:
    flat_dict = {}

    for path_tuple, value in pytree.tree_leaves_with_path(tree):
        path_segments = []

        for key in path_tuple:
            match key:
                case pytree.MappingKey(k):
                    path_segments.append(str(k))
                case pytree.SequenceKey(idx):
                    path_segments.append(str(idx))
                case pytree.GetAttrKey(name):
                    path_segments.append(name)
                case _:
                    path_segments.append(str(key))

        flat_key = "/".join(path_segments)
        flat_dict[flat_key] = value

    return flat_dict


class JobLogger(Stateful):
    """
    Handles the logging of training metrics and loss values.

    This class coordinates with the distributed context and metric calculators
    to log instantaneous loss values and periodic aggregated metrics to the
    configured experiment tracker.
    """

    def __init__(
            self,
            dist_context: DistributedContext,
            config: JobLoggerConfig,
            metrics: ComposeMetric,
            stepper: Stepper,
            run_config: RunConfig,
            additional_hparams: ScalarTree
    ):
        """
        Constructs JobLogger object.

        Args:
            dist_context: The distributed context.
            config: Configuration settings.
            metrics: The composite metric collection to be computed and logged.
            stepper: Object tracking the current global step.
            run_config: Run configuration.
            additional_hparams: Extra hyperparameters recorded alongside the run configuration.
        """

        self._dist_context = dist_context
        self._config = config
        self._metrics = metrics
        self._stepper = stepper
        self._run_config = run_config.model_copy(deep=True, update={"hparams": {
            "run": run_config.hparams,
            "params": additional_hparams
        }})

        self._tracker = self._build_tracker()

    def _build_tracker(self) -> BaseTracker:
        if self._dist_context.is_main_process:
            return tracker_from_config(self._config.tracker)
        else:
            return tracker_from_config(NullTrackerConfig())

    @contextmanager
    def new_run(self) -> Generator[BaseTrackerRun, None, None]:
        """Opens a tracker run using the composed run configuration."""
        with self._tracker.open(self._run_config) as run:
            yield run

    def trigger_sync(self):
        """
        Conditionally initiates the synchronization of distributed metrics.

        Checks if the current step is scheduled for metric logging. If so, it
        triggers the asynchronous communication required to aggregate metric values
        across ranks. This allows communication to overlap with other operations
        before `log` is called.
        """

        if not self._stepper.should_do_action(self._config.period_steps, enable_on_last_step_if_periodic=True):
            return

        self._metrics.trigger_sync(self._dist_context)

    def log(self, run: BaseTrackerRun, loss_value: torch.Tensor):
        """
        Logs the current loss and conditionally processes aggregated metrics.

        This method always logs the provided loss value. Periodically (determined
        by the stepper and configuration), it waits for the synchronization of
        metrics to complete (initiated by `trigger_sync`), computes their values,
        flattens the result structure, logs them to the tracker, and resets the
        metrics for the next window.

        Args:
            run: The active tracker run interface for sending data.
            loss_value: Tensor containing the scalar loss for the current step.
        """

        run.scalar("loss", loss_value.item())

        if not self._stepper.should_do_action(self._config.period_steps, enable_on_last_step_if_periodic=True):
            return

        self._metrics.wait_sync(self._dist_context)

        results_tree = self._metrics.compute()
        results_tree = pytree.tree_map(lambda x: x.item(), results_tree)
        results_flat = _flatten_pytree_for_metrics(results_tree)

        for name, value in results_flat.items():
            run.scalar(name, value)

        self._metrics.reset()

    def state_dict(self) -> dict[str, Any]:
        return {
            "tracker": state_dict_main_process(self._dist_context, self._tracker),
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        load_state_dict_main_process(self._dist_context, self._tracker, state_dict["tracker"])
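
A hedged usage sketch of the logging flow described above, inferred only from the public signatures in this file; `data_loader` and `train_step` are placeholders for the caller's own objects, and the import path simply mirrors the file location:

from d9d.loop.component.job_logger import JobLogger

def logging_loop(logger: JobLogger, data_loader, train_step) -> None:
    # Hypothetical driver: train_step is assumed to return a scalar loss tensor.
    with logger.new_run() as run:
        for batch in data_loader:
            loss = train_step(batch)
            # Start the asynchronous metric aggregation early so communication
            # overlaps with the rest of the step.
            logger.trigger_sync()
            # Always logs the loss; on scheduled steps it also waits for the
            # sync, flattens the metric tree into "a/b/c" keys and logs them.
            logger.log(run, loss)
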
d9d/loop/component/job_profiler.py

@@ -0,0 +1,62 @@
from collections.abc import Generator
from contextlib import contextmanager

import torch.profiler

from d9d.core.dist_context import DistributedContext
from d9d.internals.profiling import Profiler
from d9d.loop.config import ProfilingConfig

from .stepper import Stepper


class JobProfiler:
    """
    Manages profiling sessions during a job loop.

    This class coordinates the initialization and activation of the internal
    profiler based on the current step count provided by the stepper.
    """

    def __init__(
            self,
            dist_context: DistributedContext,
            config: ProfilingConfig | None,
            stepper: Stepper
    ):
        """
        Constructs JobProfiler object.

        Args:
            dist_context: The distributed context.
            config: Configuration settings for profiling.
            stepper: Object tracking the current global step of the training loop.
        """

        self._config = config
        if config is None or not config.enabled:
            self._profiler = None
        else:
            self._profiler = Profiler(
                save_dir=config.traces_dir,
                active_steps=config.active_steps,
                warmup_steps=config.warmup_steps,
                period_steps=config.period_steps,
                dist_context=dist_context
            )
        self._stepper = stepper

    @contextmanager
    def open(self) -> Generator[torch.profiler.profile | None]:
        """
        Context manager to activate profiling for the job loop.

        Yields:
            The active torch.profiler.profile instance if profiling is enabled, otherwise None.
        """

        if self._profiler is None:
            yield None
        else:
            with self._profiler.open(self._stepper.current_step) as prof:
                yield prof
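
A minimal sketch of wrapping a step loop in the `open` context manager above; whether the caller or the internal Profiler advances the profiling schedule is not shown in this file, so the `prof.step()` call below is an assumption:

from d9d.loop.component.job_profiler import JobProfiler

def profiled_steps(profiler: JobProfiler, step_fn, num_steps: int) -> None:
    # Hypothetical driver: step_fn performs one training step.
    with profiler.open() as prof:
        for _ in range(num_steps):
            step_fn()
            if prof is not None:
                # Advance the torch.profiler schedule (wait/warmup/active).
                prof.step()
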
d9d/loop/component/loss_computer.py

@@ -0,0 +1,86 @@
import torch

from d9d.internals.pipeline_state import PipelineStateHandler
from d9d.loop.control import ComputeLossContext, TrainTask

from .stepper import Stepper

STATE_LOSS = "__internal_loss"
STATE_LOSS_WEIGHT = "__internal_loss_weight"


class LossComputer:
    """
    Handles the computation of loss values and their integration into the pipeline state.

    This component acts as a bridge between the raw outputs of the model pipeline
    and the user-defined training task. It retrieves the appropriate state context
    (potentially sharded per microbatch), executes the user's loss logic, persists
    metrics into the state for logging, and returns the loss*weight term for backpropagation.
    """

    def __init__(
            self,
            state: PipelineStateHandler,
            task: TrainTask,
            stepper: Stepper
    ):
        """
        Constructs a new LossComputer.

        Args:
            state: Handler for managing global and sharded pipeline states.
            task: The user-defined training task containing loss computation logic.
            stepper: Component tracking current step and progress.
        """

        self._state = state
        self._task = task
        self._stepper = stepper

    def compute_loss_mul_weight(
            self,
            pipeline_outputs: dict[str, torch.Tensor],
            microbatch_idx: int | None
    ) -> torch.Tensor:
        """
        Computes the weighted loss for a specific sharded microbatch or the full microbatch.

        This method retrieves the appropriate state context based on the microbatch
        index, delegates calculation to the training task, saves the raw loss and
        weight into the state for later retrieval, and returns the final scalar
        product used for backward passes.

        You can retrieve states by using `STATE_LOSS` and `STATE_LOSS_WEIGHT` keys.

        Args:
            pipeline_outputs: Dictionary containing model output tensors.
            microbatch_idx: Index of the current microbatch, or `None` for full microbatch execution.

        Returns:
            The calculated loss multiplied by its weight.
        """

        if microbatch_idx is None:
            state = self._state.global_state()
        else:
            state = self._state.sharded_state(
                shard_id=microbatch_idx
            )

        computation = self._task.compute_loss(ComputeLossContext(
            pipeline_results=pipeline_outputs,
            state=state,
            stepper=self._stepper
        ))

        loss = computation.loss
        loss_weight = computation.loss_weight

        if loss_weight is None:
            loss_weight = torch.ones_like(loss)

        state[STATE_LOSS] = loss[None]
        state[STATE_LOSS_WEIGHT] = loss_weight[None]

        return loss * loss_weight
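
The per-microbatch `STATE_LOSS` and `STATE_LOSS_WEIGHT` entries stored above allow the effective loss to be recovered later. A small sketch of the weighted reduction this implies (an assumption about how consumers combine the stored values, not code from this package):

import torch

def weighted_mean_loss(losses: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # losses and weights are 1D tensors, one entry per microbatch, as stored
    # via loss[None] / loss_weight[None] above.
    return (losses * weights).sum() / weights.sum()
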
d9d/loop/component/model_stage_exporter.py

@@ -0,0 +1,37 @@
from pathlib import Path

from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
from d9d.loop.control import ModelProvider, PrepareExportModelStageContext
from d9d.model_state.io import save_model_state_pipeline_parallel
from d9d.model_state.mapper.compose import ModelStateMapperParallel

from .model_stage_factory import TrackedModules


class ModelStageExporter:
    """Exports the locally tracked model stages to disk, using the state mappers supplied by the model provider."""

    def __init__(
            self,
            model_provider: ModelProvider,
            modules: TrackedModules,
            dist_context: DistributedContext
    ):
        self._model_provider = model_provider
        self._modules = modules
        self._dist_context = dist_context

    def export(self, save_dir: Path):
        mappers = []
        for stage in self._modules.modules:
            result = self._model_provider.prepare_export_model_stage(PrepareExportModelStageContext(
                model=stage,
                dist_context=self._dist_context
            ))
            mappers.append(result.state_mapper)
        save_model_state_pipeline_parallel(
            dest_dir=save_dir,
            mapper=ModelStateMapperParallel(mappers),
            device_mesh=self._dist_context.mesh_for(REGULAR_DOMAIN),
            pipeline_dim_name="pp",
            models=self._modules.modules,
            show_progress=True,
        )
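
A brief, hedged call-site sketch for the exporter above; the wiring of the provider, TrackedModules and distributed context is assumed to come from the surrounding training loop:

from pathlib import Path

from d9d.loop.component.model_stage_exporter import ModelStageExporter

def export_after_training(model_provider, tracked_modules, dist_context, out_dir: str) -> None:
    # Hypothetical wiring: reuse the objects built by ModelStageFactory, then
    # write the merged model state across the "pp" mesh dimension to out_dir.
    exporter = ModelStageExporter(model_provider, tracked_modules, dist_context)
    exporter.export(Path(out_dir))
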
d9d/loop/component/model_stage_factory.py

@@ -0,0 +1,261 @@
import itertools
from collections.abc import Callable
from typing import Any

import torch
from torch import nn
from torch.distributed.checkpoint.stateful import Stateful

from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
from d9d.loop.config import ModelStageFactoryConfig, PipeliningConfig
from d9d.loop.control import InitializeModelStageContext, ModelProvider, ParallelizeModelStageContext
from d9d.model_state.io import load_model_state
from d9d.module.base import ModuleLateInit
from d9d.pipelining.api import PipelineStageInfo
from d9d.pipelining.factory.factory import PipelineScheduleInfo, build_schedule

from .batch_maths import BatchMaths
from .loss_computer import LossComputer

StatefulPredicate = Callable[[str, torch.Tensor], bool]
"""Determines if a specific parameter or buffer should be included in the state dictionary."""


def _stateful_predicate_requires_grad(key: str, value: torch.Tensor) -> bool:
    """Predicate that allows saving only tensors that require gradients."""
    return value.requires_grad


def _stateful_predicate_always(key: str, value: torch.Tensor) -> bool:
    """Predicate that always allows saving."""
    return True


class TrackedModules(Stateful):
    """
    Wraps a list of model stages and manages their state for distributed checkpointing.

    This class implements the PyTorch Distributed `Stateful` protocol, aggregating
    the state dictionaries of multiple pipeline stages assigned to the current rank.
    It handles namespacing to ensure uniqueness across pipeline ranks and stages.
    """

    def __init__(
            self,
            dist_context: DistributedContext,
            modules: list[nn.Module],
            stateful_predicate: StatefulPredicate
    ):
        """Constructs a TrackedModules object."""
        self._dist_context = dist_context
        self._modules = modules
        self._stateful_predicate = stateful_predicate

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Forwards execution to the only pipeline stage.

        This method is only valid when pipeline parallelism is disabled.

        Args:
            *args: Positional arguments passed to the module.
            **kwargs: Keyword arguments passed to the module.

        Returns:
            The output of the model execution.

        Raises:
            ValueError: If pipeline parallelism is configured.
        """

        if self._dist_context.mesh_params.has_pipeline_parallel:
            raise ValueError("You cannot call tracked modules when using pipelining")

        return self._modules[0](*args, **kwargs)

    @property
    def modules(self) -> list[nn.Module]:
        """Returns the list of underlying PyTorch model modules."""
        return self._modules

    def _whitelisted_params(self, module: nn.Module) -> set[str]:
        allow_saving = set()
        for param_name, param in itertools.chain(module.named_parameters(), module.named_buffers()):
            if self._stateful_predicate(param_name, param):
                allow_saving.add(param_name)
        return allow_saving

    def _state_dict_stage(self, module: nn.Module) -> dict[str, Any]:
        whitelist = self._whitelisted_params(module)
        result = {
            k: v for k, v in module.state_dict().items() if k in whitelist
        }
        return result

    def state_dict(self) -> dict[str, Any]:
        """
        Generates the state dictionary for all tracked modules.

        The keys are namespaced using the current pipeline rank and stage index
        (e.g., `pp_0_stage_0`). Only parameters satisfying the `stateful_predicate`
        are included.

        Returns:
            A dictionary containing the states of all managed modules.
        """

        pp_rank = self._dist_context.mesh_for(REGULAR_DOMAIN)["pp"].get_local_rank()
        ret = {
            f"pp_{pp_rank}_stage_{i}": self._state_dict_stage(module)
            for i, module in enumerate(self._modules)
        }
        return ret

    def _load_state_dict_stage(self, module: nn.Module, state_dict: dict[str, Any]):
        whitelist = self._whitelisted_params(module)

        loading_result = module.load_state_dict(state_dict, strict=False)
        missing_keys = set(loading_result.missing_keys)
        extra_keys = set(loading_result.unexpected_keys)

        if len(whitelist.intersection(missing_keys)) > 0:
            raise ValueError(f"Missing keys: {whitelist.intersection(missing_keys)}")
        if len(extra_keys) > 0:
            raise ValueError(f"Extra keys: {extra_keys}")

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """
        Loads the state dictionary into the tracked modules.

        Args:
            state_dict: The state dictionary to load. Must contain keys corresponding
                to the pipeline rank and stage indices managed by this instance.

        Raises:
            ValueError: If required keys are missing or unexpected keys are present
                based on the allow-list predicate.
        """

        pp_rank = self._dist_context.mesh_for(REGULAR_DOMAIN)["pp"].get_local_rank()
        for i, module in enumerate(self._modules):
            self._load_state_dict_stage(module, state_dict[f"pp_{pp_rank}_stage_{i}"])


class ModelStageFactory:
    """
    Factory class responsible for creating, initializing, and parallelizing model stages.

    This class coordinates the `ModelProvider` with the distributed context to:

    1. Initialize models on a meta device.
    2. Apply horizontal distribution strategy (TP, DP, FSDP, etc).
    3. Materialize weights on the target device.
    4. Load initial model states from checkpoints.
    """

    def __init__(
            self,
            model_provider: ModelProvider,
            dist_context: DistributedContext,
            batch_maths: BatchMaths,
            config_model: ModelStageFactoryConfig,
            config_pipelining: PipeliningConfig | None,
            loss_computer: LossComputer | None
    ):
        """Constructs a ModelStageFactory object."""

        self._model_provider = model_provider
        self._dist_context = dist_context
        self._config_model = config_model
        self._config_pipelining = config_pipelining
        self._batch_maths = batch_maths
        self._loss_computer = loss_computer

    def _build_model_stage(self, stage: PipelineStageInfo) -> nn.Module:
        # create a model with no real memory occupied
        with torch.device("meta"):
            factored = self._model_provider.initialize_model_stage(
                InitializeModelStageContext(
                    dist_context=self._dist_context,
                    stage=stage,
                )
            )

        model = factored.model

        if not isinstance(model, ModuleLateInit) or not isinstance(model, nn.Module):
            raise ValueError("Model stage is required to be nn.Module instance implementing ModuleLateInit protocol")

        # if current context is distributed - parallelize this model
        if self._dist_context.mesh_params.is_distributed:
            self._model_provider.parallelize_model_stage(
                ParallelizeModelStageContext(
                    model=model,
                    stage=stage,
                    dist_context=self._dist_context
                )
            )

        # move state that is bound to current device to it
        model.to_empty(device=self._dist_context.current_device)

        # reinitialize model parameters (only these are on current device)
        with torch.no_grad():
            model.reset_parameters()

        if self._config_model.source_checkpoint:
            load_model_state(
                src_dir=self._config_model.source_checkpoint,
                model=model,
                mapper=factored.state_mapper,
                device=f"cuda:{torch.cuda.current_device()}"
            )

        # set training state
        model.train()

        return model

    def build_pipeline_and_modules(
            self
    ) -> tuple[PipelineScheduleInfo | None, TrackedModules]:
        """
        Constructs the execution schedule and the model container.

        If pipeline parallelism is enabled, this orchestrates the creation of a
        distributed pipeline schedule.

        Otherwise, it simply builds a standalone model stage.

        Returns:
            The pipeline schedule information (or None if no pipelining).
            The `TrackedModules` instance wrapping the created model stage(s).

        Raises:
            ValueError: If pipelining configuration is missing but a pipeline is requested.
        """

        if self._config_model.checkpoint_only_trainable_parameters:
            stateful_predicate = _stateful_predicate_requires_grad
        else:
            stateful_predicate = _stateful_predicate_always

        if self._dist_context.mesh_params.has_pipeline_parallel:
            if self._config_pipelining is None:
                raise ValueError("Pipelining is enabled, but not configured")

            loss_fn = self._loss_computer.compute_loss_mul_weight if self._loss_computer is not None else None

            schedule, modules = build_schedule(
                dist_context=self._dist_context,
                n_microbatches=self._batch_maths.num_microbatches_pipelining,
                schedule_config=self._config_pipelining.schedule,
                model_provider=self._build_model_stage,
                loss_fn=loss_fn
            )

            return schedule, TrackedModules(self._dist_context, modules, stateful_predicate)
        else:
            model = self._build_model_stage(PipelineStageInfo(num_stages=1, current_stage=0))

            return None, TrackedModules(self._dist_context, [model], stateful_predicate)
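
The `_build_model_stage` method above follows the standard meta-device initialization pattern. A self-contained sketch of that pattern in plain PyTorch, independent of the d9d provider machinery:

import torch
from torch import nn

def materialize_on_device(build_fn, device: torch.device) -> nn.Module:
    # Build structure only: parameters on the "meta" device occupy no memory.
    with torch.device("meta"):
        model = build_fn()
    # Parallelism/sharding could be applied here, while storage is still fake.
    # Allocate real (uninitialized) storage on the target device...
    model.to_empty(device=device)
    # ...and re-run parameter initialization in place.
    with torch.no_grad():
        for module in model.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()
    model.train()
    return model

# Example: materialize_on_device(lambda: nn.Linear(1024, 1024), torch.device("cpu"))
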
d9d/loop/component/optimizer_factory.py

@@ -0,0 +1,88 @@
from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
from d9d.core.protocol import LRSchedulerProtocol, OptimizerProtocol
from d9d.loop.control import (
    InitializeLRSchedulerContext,
    InitializeOptimizerStageContext,
    LRSchedulerProvider,
    OptimizerProvider,
)
from d9d.pipelining.training import PipelinedLRScheduler, PipelinedOptimizer

from .model_stage_factory import TrackedModules
from .stepper import Stepper


class OptimizerFactory:
    """
    Factory for creating and configuring distributed optimizers and learning rate schedulers.

    This factory handles the orchestration of optimizer creation for models potentially split across
    pipeline stages. It uses the providers to instantiate underlying PyTorch optimizers and schedulers for each
    tracked module, and wraps them in pipeline-aware interfaces.
    """

    def __init__(
            self,
            dist_context: DistributedContext,
            tracked_modules: TrackedModules,
            optimizer_provider: OptimizerProvider,
            lr_scheduler_provider: LRSchedulerProvider,
            stepper: Stepper
    ):
        """
        Constructs the OptimizerFactory.

        Args:
            dist_context: The distributed context.
            tracked_modules: A container of model modules owned by the current rank.
            optimizer_provider: A callable responsible for creating optimizer instances for a given model.
            lr_scheduler_provider: A callable responsible for creating LR scheduler instances.
            stepper: The training stepper providing information about total training steps.
        """
        self._dist_context = dist_context
        self._tracked_modules = tracked_modules
        self._optimizer_provider = optimizer_provider
        self._lr_scheduler_provider = lr_scheduler_provider
        self._stepper = stepper

    def build_optimizer_and_scheduler(self) -> tuple[OptimizerProtocol, LRSchedulerProtocol]:
        """
        Builds both the optimizer and learning rate scheduler.

        This method iterates through all local model modules. For each module, it creates an
        optimizer and scheduler using the configured providers. Finally, it aggregates these individual
        instances into a single `PipelinedOptimizer` and `PipelinedLRScheduler` capable of coordinated
        stepping across the pipeline parallel dimension.

        Returns:
            A tuple containing the initialized pipeline-aware optimizer and scheduler.
        """

        optimizers: list[OptimizerProtocol] = []
        lr_schedulers: list[LRSchedulerProtocol] = []
        for module in self._tracked_modules.modules:
            optimizer = self._optimizer_provider(
                InitializeOptimizerStageContext(
                    dist_context=self._dist_context,
                    model=module
                )
            )
            optimizers.append(optimizer)

            scheduler = self._lr_scheduler_provider(
                InitializeLRSchedulerContext(
                    dist_context=self._dist_context,
                    total_steps=self._stepper.total_steps,
                    optimizer=optimizer
                )
            )
            lr_schedulers.append(scheduler)
        pipe_optimizer = PipelinedOptimizer(
            mesh_pp=self._dist_context.mesh_for(REGULAR_DOMAIN)["pp"],
            optimizers=optimizers
        )
        pipe_scheduler = PipelinedLRScheduler(
            mesh_pp=self._dist_context.mesh_for(REGULAR_DOMAIN)["pp"],
            schedulers=lr_schedulers
        )
        return pipe_optimizer, pipe_scheduler
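
The aggregation at the end of `build_optimizer_and_scheduler` hides one optimizer per local stage behind a single interface. A minimal illustration of that idea (not the d9d `PipelinedOptimizer` implementation, which additionally coordinates across the pipeline mesh):

import torch

class NaivePipelinedOptimizer:
    """Steps one optimizer per locally held pipeline stage behind a single facade."""

    def __init__(self, optimizers: list[torch.optim.Optimizer]):
        self._optimizers = optimizers

    def step(self) -> None:
        for opt in self._optimizers:
            opt.step()

    def zero_grad(self) -> None:
        for opt in self._optimizers:
            opt.zero_grad(set_to_none=True)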