d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0

d9d/lr_scheduler/piecewise/config.py
ADDED
@@ -0,0 +1,176 @@
+from typing import Annotated, Literal
+
+from pydantic import BaseModel, Field, PositiveInt
+from torch.optim import Optimizer
+
+from d9d.core.protocol import LRSchedulerProtocol
+
+from .builder import piecewise_schedule
+from .curves import CurveBase, CurveCosine, CurveExponential, CurveLinear, CurvePoly
+
+
+class CurveLinearConfig(BaseModel):
+    """
+    Configuration for linear interpolation.
+    """
+
+    type: Literal["linear"] = "linear"
+
+
+class CurveCosineConfig(BaseModel):
+    """
+    Configuration for cosine interpolation.
+    """
+
+    type: Literal["cosine"] = "cosine"
+
+
+class CurveExponentialConfig(BaseModel):
+    """
+    Configuration for exponential interpolation.
+    """
+
+    type: Literal["exponential"] = "exponential"
+
+
+class CurvePolyConfig(BaseModel):
+    """
+    Configuration for polynomial interpolation.
+
+    Attributes:
+        power: The exponent of the polynomial function.
+    """
+
+    type: Literal["poly"] = "poly"
+    power: float = 2.0
+
+
+AnyCurveConfig = Annotated[
+    CurveLinearConfig | CurveCosineConfig | CurveExponentialConfig | CurvePolyConfig,
+    Field(discriminator="type")
+]
+
+
+def curve_from_config(config: AnyCurveConfig) -> CurveBase:
+    """
+    Instantiates a concrete curve object from its configuration.
+
+    Args:
+        config: The configuration object.
+
+    Returns:
+        The instantiated curve.
+    """
+
+    match config:
+        case CurveLinearConfig():
+            return CurveLinear()
+        case CurvePolyConfig():
+            return CurvePoly(config.power)
+        case CurveExponentialConfig():
+            return CurveExponential()
+        case CurveCosineConfig():
+            return CurveCosine()
+
+
+class StepPhaseConfig(BaseModel):
+    """
+    Configuration for a phase defined by a fixed number of steps.
+
+    Attributes:
+        mode: Discriminator field, must be "steps".
+        steps: The absolute duration of this phase in steps.
+        target_multiplier: The multiplier value at the end of this phase.
+        curve: The interpolation curve configuration.
+    """
+
+    mode: Literal["steps"] = "steps"
+
+    steps: PositiveInt
+    target_multiplier: float
+    curve: AnyCurveConfig
+
+
+class PercentagePhaseConfig(BaseModel):
+    """
+    Configuration for a phase that lasts until a specific percentage of training is complete.
+
+    Attributes:
+        mode: Discriminator field, must be "percentage".
+        percentage: The target progress (0.0 to 1.0) where this phase ends.
+        target_multiplier: The multiplier value at the end of this phase.
+        curve: The interpolation curve configuration.
+    """
+
+    mode: Literal["percentage"] = "percentage"
+
+    percentage: float = Field(..., ge=0.0, le=1.0)
+    target_multiplier: float
+    curve: AnyCurveConfig
+
+
+class RestPhaseConfig(BaseModel):
+    """
+    Configuration for a phase that fills the remainder of the training duration.
+
+    Attributes:
+        mode: Discriminator field, must be "rest".
+        target_multiplier: The multiplier value at the very end of training.
+        curve: The interpolation curve configuration.
+    """
+
+    mode: Literal["rest"] = "rest"
+
+    target_multiplier: float
+    curve: AnyCurveConfig
+
+
+PhaseConfig = Annotated[
+    StepPhaseConfig | PercentagePhaseConfig | RestPhaseConfig,
+    Field(discriminator="mode")
+]
+
+
+class PiecewiseSchedulerConfig(BaseModel):
+    """
+    Declarative configuration for a piecewise learning rate scheduler.
+
+    Attributes:
+        initial_multiplier: The starting learning rate multiplier.
+        phases: A sequential list of phase configurations.
+    """
+
+    initial_multiplier: float
+    phases: list[PhaseConfig]
+
+
+def piecewise_scheduler_from_config(
+    config: PiecewiseSchedulerConfig,
+    optimizer: Optimizer,
+    total_steps: int | None
+) -> LRSchedulerProtocol:
+    """
+    Constructs a PyTorch scheduler from the provided configuration.
+
+    Args:
+        config: The scheduler configuration.
+        optimizer: The optimizer to wrap.
+        total_steps: The total number of training steps. Required if using percentage-based phases.
+
+    Returns:
+        A configured learning rate scheduler.
+    """
+
+    builder = piecewise_schedule(config.initial_multiplier, total_steps)
+
+    for phase in config.phases:
+        curve = curve_from_config(phase.curve)
+        match phase:
+            case StepPhaseConfig():
+                builder.for_steps(phase.steps, phase.target_multiplier, curve)
+            case PercentagePhaseConfig():
+                builder.until_percentage(phase.percentage, phase.target_multiplier, curve)
+            case RestPhaseConfig():
+                builder.fill_rest(phase.target_multiplier, curve)
+
+    return builder.build(optimizer)
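
Note (not part of the diff): a minimal usage sketch for the configuration API above. It assumes the resulting scheduler scales the optimizer's base learning rate by the computed multiplier; the builder internals live in builder.py, which is not shown in this section.

    from torch import nn
    from torch.optim import SGD

    # Symbols are defined in config.py above; importing from that module directly.
    from d9d.lr_scheduler.piecewise.config import (
        CurveCosineConfig,
        CurveLinearConfig,
        PiecewiseSchedulerConfig,
        RestPhaseConfig,
        StepPhaseConfig,
        piecewise_scheduler_from_config,
    )

    # 100-step linear warmup to the base LR, then cosine decay to 10% of it.
    config = PiecewiseSchedulerConfig(
        initial_multiplier=0.0,
        phases=[
            StepPhaseConfig(steps=100, target_multiplier=1.0, curve=CurveLinearConfig()),
            RestPhaseConfig(target_multiplier=0.1, curve=CurveCosineConfig()),
        ],
    )

    optimizer = SGD(nn.Linear(8, 8).parameters(), lr=3e-4)
    scheduler = piecewise_scheduler_from_config(config, optimizer, total_steps=10_000)

    for _ in range(10_000):
        optimizer.step()
        scheduler.step()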
d9d/lr_scheduler/piecewise/curves.py
ADDED
@@ -0,0 +1,75 @@
+import abc
+import math
+
+
+class CurveBase(abc.ABC):
+    """
+    Abstract base class for interpolation curves used in scheduling.
+    """
+
+    @abc.abstractmethod
+    def compute(self, start: float, end: float, step_p: float) -> float:
+        """
+        Calculates the interpolated value.
+
+        Args:
+            start: The value at the beginning of the phase.
+            end: The value at the end of the phase.
+            step_p: Progress fraction through the phase (0.0 to 1.0).
+
+        Returns:
+            The interpolated value.
+        """
+
+
+class CurveLinear(CurveBase):
+    """
+    Linearly interpolates between start and end values.
+    """
+
+    def compute(self, start: float, end: float, step_p: float) -> float:
+        return start + (end - start) * step_p
+
+
+class CurveCosine(CurveBase):
+    """
+    Interpolates using a cosine annealing schedule (half-period cosine).
+    """
+
+    def compute(self, start: float, end: float, step_p: float) -> float:
+        cos_out = (1 + math.cos(math.pi * step_p)) / 2
+        return end + (start - end) * cos_out
+
+
+class CurvePoly(CurveBase):
+    """
+    Interpolates using a polynomial function.
+    """
+
+    def __init__(self, power: float):
+        """
+        Constructs a polynomial curve.
+
+        Args:
+            power: The exponent of the polynomial. 1.0 is linear, 2.0 is quadratic, etc.
+        """
+
+        self._power = power
+
+    def compute(self, start: float, end: float, step_p: float) -> float:
+        p_transformed = step_p ** self._power
+        return start + (end - start) * p_transformed
+
+
+class CurveExponential(CurveBase):
+    """
+    Interpolates exponentially between start and end values (log-space linear).
+    """
+
+    def compute(self, start: float, end: float, step_p: float) -> float:
+        eps = 1e-8
+        safe_start = max(start, eps)
+        safe_end = max(end, eps)
+
+        out_log = math.log(safe_start) + (math.log(safe_end) - math.log(safe_start)) * step_p
+        return math.exp(out_log)
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
|
|
3
|
+
from .curves import CurveBase
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclasses.dataclass
|
|
7
|
+
class SchedulePhase:
|
|
8
|
+
"""
|
|
9
|
+
Data container representing a single phase in a piecewise schedule.
|
|
10
|
+
|
|
11
|
+
Attributes:
|
|
12
|
+
start_step: The absolute step index where this phase begins.
|
|
13
|
+
end_step: The absolute step index where this phase ends.
|
|
14
|
+
start_value: The multiplier value at start_step.
|
|
15
|
+
end_value: The multiplier value at end_step.
|
|
16
|
+
curve: The interpolation logic for this phase.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
start_step: int
|
|
20
|
+
end_step: int
|
|
21
|
+
start_value: float
|
|
22
|
+
end_value: float
|
|
23
|
+
curve: CurveBase
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PiecewiseScheduleEngine:
|
|
27
|
+
"""
|
|
28
|
+
Runtime engine that calculates multipliers based on a list of defined phases.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
def __init__(self, phases: list[SchedulePhase]):
|
|
32
|
+
"""
|
|
33
|
+
Constructs the schedule engine.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
phases: A sequential list of schedule phases.
|
|
37
|
+
|
|
38
|
+
Raises:
|
|
39
|
+
ValueError: If the phases list is empty.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
if len(phases) == 0:
|
|
43
|
+
raise ValueError("Scheduler should contain at least one phase")
|
|
44
|
+
|
|
45
|
+
self._phases = phases
|
|
46
|
+
|
|
47
|
+
def get_factor(self, step: int) -> float:
|
|
48
|
+
"""
|
|
49
|
+
Computes the learning rate multiplier for the given step.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
step: The global training step.
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
The calculated multiplier. If the step is outside defined phases,
|
|
56
|
+
it clamps to the nearest boundary value.
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
if step < 0:
|
|
60
|
+
return self._phases[0].start_value
|
|
61
|
+
|
|
62
|
+
for phase in self._phases:
|
|
63
|
+
if not (phase.start_step <= step < phase.end_step):
|
|
64
|
+
continue
|
|
65
|
+
|
|
66
|
+
steps_in_phase = step - phase.start_step
|
|
67
|
+
phase_len = phase.end_step - phase.start_step
|
|
68
|
+
phase_progress = steps_in_phase / phase_len
|
|
69
|
+
|
|
70
|
+
return phase.curve.compute(
|
|
71
|
+
start=phase.start_value,
|
|
72
|
+
end=phase.end_value,
|
|
73
|
+
step_p=phase_progress
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
return self._phases[-1].end_value
|
|
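
Note (not part of the diff): a small example of the engine's interpolation and clamping behavior, using the curves from curves.py above.

    from d9d.lr_scheduler.piecewise.curves import CurveCosine, CurveLinear
    from d9d.lr_scheduler.piecewise.engine import PiecewiseScheduleEngine, SchedulePhase

    engine = PiecewiseScheduleEngine([
        SchedulePhase(start_step=0, end_step=100, start_value=0.0, end_value=1.0, curve=CurveLinear()),
        SchedulePhase(start_step=100, end_step=1000, start_value=1.0, end_value=0.1, curve=CurveCosine()),
    ])

    engine.get_factor(-5)    # 0.0 (before the first phase: clamps to its start_value)
    engine.get_factor(50)    # 0.5 (halfway through the linear warmup)
    engine.get_factor(5000)  # 0.1 (past the last phase: clamps to its end_value)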
d9d/lr_scheduler/visualizer.py
ADDED
@@ -0,0 +1,74 @@
+from collections.abc import Callable
+
+from torch import nn
+from torch.optim import SGD, Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+
+SchedulerFactory = Callable[[Optimizer], LRScheduler]
+
+
+def _get_history(factory: SchedulerFactory, num_steps: int, init_lr: float) -> list[float]:
+    optimizer = SGD(nn.Linear(1, 1).parameters(), lr=init_lr)
+
+    scheduler = factory(optimizer)
+
+    lrs = []
+
+    for _ in range(num_steps):
+        current_lr = optimizer.param_groups[0]["lr"]
+        lrs.append(current_lr)
+        scheduler.step()
+
+    return lrs
+
+
+def visualize_lr_scheduler(factory: SchedulerFactory, num_steps: int, init_lr: float = 1.0):
+    """
+    Visualizes the learning rate schedule using Plotly.
+
+    This function simulates the training process for `num_steps` to record the LR changes
+    and generates an interactive plot.
+
+    Args:
+        factory: A callable that accepts an Optimizer and returns an LRScheduler.
+        num_steps: The number of steps to simulate.
+        init_lr: The initial learning rate to set on the dummy optimizer.
+
+    Raises:
+        ImportError: If the `plotly` library is not installed.
+    """
+
+    try:
+        import plotly.graph_objects as go  # noqa: PLC0415
+    except ImportError as e:
+        raise ImportError("You have to install `plotly` dependency to use scheduler visualization") from e
+    lrs = _get_history(factory, num_steps, init_lr)
+    steps = list(range(num_steps))
+
+    fig = go.Figure()
+
+    fig.add_trace(go.Scatter(
+        x=steps,
+        y=lrs,
+        mode="lines",
+        name="Learning Rate",
+        line={"color": "#636EFA", "width": 3},
+        hovertemplate="<b>Step:</b> %{x}<br><b>LR:</b> %{y:.6f}<extra></extra>"
+    ))
+
+    fig.update_layout(
+        title={
+            "text": "Scheduler",
+            "y": 0.95,
+            "x": 0.5,
+            "xanchor": "center",
+            "yanchor": "top"
+        },
+        xaxis_title="Steps",
+        yaxis_title="Learning Rate",
+        template="plotly_white",
+        hovermode="x unified",
+        height=500
+    )
+
+    fig.show()
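
Note (not part of the diff): a usage sketch for the visualizer. It needs `plotly` installed; the `LambdaLR` warmup below is only a stand-in for any factory matching `SchedulerFactory`.

    from torch.optim.lr_scheduler import LambdaLR

    from d9d.lr_scheduler.visualizer import visualize_lr_scheduler

    visualize_lr_scheduler(
        factory=lambda opt: LambdaLR(opt, lr_lambda=lambda step: min(1.0, step / 100)),
        num_steps=1000,
        init_lr=3e-4,
    )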
d9d/metric/__init__.py
ADDED
d9d/metric/abc.py
ADDED
@@ -0,0 +1,79 @@
+import abc
+from typing import Any, Generic, TypeVar
+
+import torch
+from torch.distributed.checkpoint.stateful import Stateful
+
+from d9d.core.dist_context import DistributedContext
+from d9d.core.types import TensorTree
+
+TComputeResult = TypeVar("TComputeResult", bound=TensorTree)
+
+
+class Metric(abc.ABC, Stateful, Generic[TComputeResult]):
+    """
+    Abstract base class for all metrics.
+
+    Metrics track statistics over time (e.g., during training) and can be synchronized
+    across distributed processes. They also support state persistence via the Stateful
+    interface.
+    """
+
+    @abc.abstractmethod
+    def update(self, *args: Any, **kwargs: Any):
+        """
+        Updates the metric state with a new batch of data.
+
+        Args:
+            *args: Positional arguments required by the specific metric implementation.
+            **kwargs: Keyword arguments required by the specific metric implementation.
+        """
+
+    @abc.abstractmethod
+    def trigger_sync(self, dist_context: DistributedContext):
+        """
+        Initiates the synchronization of the metric state across distributed processes.
+
+        This method should start the collective operations (e.g., all-reduce) required
+        to aggregate statistics, but should not block waiting for completion if possible.
+
+        Args:
+            dist_context: The distributed context.
+        """
+
+    @abc.abstractmethod
+    def wait_sync(self, dist_context: DistributedContext):
+        """
+        Waits for the synchronization initiated by `trigger_sync` to complete.
+
+        After this method returns, the metric state must be fully aggregated and
+        consistent across ranks.
+
+        Args:
+            dist_context: The distributed context.
+        """
+
+    @abc.abstractmethod
+    def compute(self) -> TComputeResult:
+        """
+        Computes the current value of the metric.
+
+        Returns:
+            The computed metric result (of type `TComputeResult`).
+            This can be a single `torch.Tensor` or `PyTree` structure (dict, list, etc.)
+            containing tensors, depending on how the subclass was typed.
+        """
+
+    @abc.abstractmethod
+    def reset(self):
+        """
+        Resets the internal state of the metric to the initial values.
+        """
+
+    def to(self, device: str | torch.device | int):
+        """
+        Moves a metric state to a specified device.
+
+        Args:
+            device: The device to move the metric state to.
+        """
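
Note (not part of the diff): the docstrings above describe a two-phase sync protocol. Roughly how a training loop would drive it; how a `DistributedContext` is obtained is not shown in this section, so `dist_context` is assumed to already exist and `metric` is any concrete subclass.

    metric.update(batch_values, batch_weights)  # accumulate local statistics (signature depends on the subclass)
    metric.trigger_sync(dist_context)           # launch async collectives, non-blocking
    # ... overlap other work here ...
    metric.wait_sync(dist_context)              # block until the collectives complete
    result = metric.compute()                   # globally aggregated result
    metric.reset()                              # start a fresh accumulation window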
d9d/metric/impl/compose.py
ADDED
@@ -0,0 +1,54 @@
+from collections.abc import Mapping
+from typing import Any
+
+import torch
+
+from d9d.core.dist_context import DistributedContext
+from d9d.metric import Metric
+
+
+class ComposeMetric(Metric[dict[str, Any]]):
+    def __init__(self, children: Mapping[str, Metric]):
+        self._children = children
+
+    def update(self, *args: Any, **kwargs: Any):
+        raise ValueError("Cannot update ComposeMetric directly - you can only update its children")
+
+    def __getitem__(self, item: str) -> Metric:
+        return self._children[item]
+
+    @property
+    def children(self) -> Mapping[str, Metric]:
+        return self._children
+
+    def trigger_sync(self, dist_context: DistributedContext):
+        for metric in self._children.values():
+            metric.trigger_sync(dist_context)
+
+    def wait_sync(self, dist_context: DistributedContext):
+        for metric in self._children.values():
+            metric.wait_sync(dist_context)
+
+    def compute(self) -> dict[str, Any]:
+        return {
+            metric_name: metric.compute()
+            for metric_name, metric in self._children.items()
+        }
+
+    def reset(self):
+        for metric in self._children.values():
+            metric.reset()
+
+    def to(self, device: str | torch.device | int):
+        for metric in self._children.values():
+            metric.to(device)
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            metric_name: metric.state_dict()
+            for metric_name, metric in self._children.items()
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        for metric_name, metric in self._children.items():
+            metric.load_state_dict(state_dict[metric_name])
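
Note (not part of the diff): a short sketch of composing metrics; `WeightedMeanMetric` comes from mean.py below.

    import torch

    from d9d.metric.impl.compose import ComposeMetric
    from d9d.metric.impl.mean import WeightedMeanMetric

    metrics = ComposeMetric({
        "loss": WeightedMeanMetric(),
        "grad_norm": WeightedMeanMetric(),
    })

    # Children are updated individually; calling update() on the composite raises.
    metrics["loss"].update(values=torch.tensor([0.7, 0.5]), weights=torch.tensor([4.0, 4.0]))
    metrics["grad_norm"].update(values=torch.tensor([1.2]), weights=torch.tensor([1.0]))

    metrics.compute()  # {"loss": tensor(0.6000), "grad_norm": tensor(1.2000)}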
d9d/metric/impl/mean.py
ADDED
@@ -0,0 +1,94 @@
+from typing import Any
+
+import torch
+import torch.distributed as dist
+
+from d9d.core.dist_context import DistributedContext
+from d9d.metric import Metric
+
+
+class WeightedMeanMetric(Metric[torch.Tensor]):
+    """
+    Computes the weighted mean of values.
+
+    Tracks the sum of weighted values and the sum of weights.
+    """
+
+    def __init__(self):
+        """Constructs a WeightedMeanMetric object."""
+
+        super().__init__()
+        self._value = torch.scalar_tensor(0, dtype=torch.float32)
+        self._weight = torch.scalar_tensor(0, dtype=torch.float32)
+
+        self._is_synced = False
+        self._synced_value = torch.scalar_tensor(0, dtype=torch.float32)
+        self._synced_weight = torch.scalar_tensor(0, dtype=torch.float32)
+
+        self._handles: list[dist.Work] | None = None
+
+    def update(self, values: torch.Tensor, weights: torch.Tensor):
+        self._value += (values * weights).sum()
+        self._weight += weights.sum()
+
+        self._is_synced = False
+
+    def trigger_sync(self, dist_context: DistributedContext):
+        self._synced_value = self._value.clone()
+        self._synced_weight = self._weight.clone()
+        self._is_synced = True
+
+        self._handles = [
+            dist.all_reduce(self._synced_value, op=dist.ReduceOp.SUM, async_op=True),
+            dist.all_reduce(self._synced_weight, op=dist.ReduceOp.SUM, async_op=True)
+        ]
+
+    def wait_sync(self, dist_context: DistributedContext):
+        if self._handles is None:
+            raise RuntimeError("Sync was not triggered before")
+
+        for handle in self._handles:
+            handle.wait()
+        self._handles = None
+
+    def compute(self) -> torch.Tensor:
+        if self._is_synced:
+            return self._synced_value / self._synced_weight
+        else:
+            return self._value / self._weight
+
+    def reset(self):
+        self._value.fill_(0)
+        self._weight.fill_(0)
+        self._is_synced = False
+        self._handles = None
+
+    def to(self, device: str | torch.device | int):
+        self._weight = self._weight.to(device)
+        self._value = self._value.to(device)
+        self._synced_weight = self._synced_weight.to(device)
+        self._synced_value = self._synced_value.to(device)
+
+    @property
+    def accumulated_weight(self) -> torch.Tensor:
+        """
+        Returns the total weight accumulated so far.
+
+        Returns:
+            Scalar tensor with total weight.
+        """
+
+        if self._is_synced:
+            return self._synced_weight
+
+        return self._weight
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "value": self._value,
+            "weight": self._weight
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self._value = state_dict["value"]
+        self._weight = state_dict["weight"]
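
Note (not part of the diff): a quick single-process check of the weighted-mean arithmetic above (no distributed sync involved).

    import torch

    from d9d.metric.impl.mean import WeightedMeanMetric

    m = WeightedMeanMetric()
    m.update(values=torch.tensor([2.0, 4.0]), weights=torch.tensor([1.0, 3.0]))
    m.compute()           # (2*1 + 4*3) / (1 + 3) = tensor(3.5000)
    m.accumulated_weight  # tensor(4.)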
d9d/model_state/__init__.py
ADDED
File without changes
d9d/model_state/io/__init__.py
ADDED
@@ -0,0 +1,21 @@
+from .module_reader import load_model_state
+from .module_writer import (
+    save_model_state,
+    save_model_state_pipeline_parallel,
+)
+from .reader import read_model_state
+from .writer import (
+    write_model_state_distributed,
+    write_model_state_local,
+    write_model_state_pipeline_parallel,
+)
+
+__all__ = [
+    "load_model_state",
+    "read_model_state",
+    "save_model_state",
+    "save_model_state_pipeline_parallel",
+    "write_model_state_distributed",
+    "write_model_state_local",
+    "write_model_state_pipeline_parallel"
+]