d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0

d9d/kernel/stochastic/adamw_step.py
ADDED
@@ -0,0 +1,204 @@
import torch
import triton
import triton.language as tl

from .ops import fp32_to_bf16_kernel


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
    ],
    key=["n_elements"],
    restore_value=["p_ptr", "m_ptr", "v_ptr"]
)
@triton.jit
def _adamw_stochastic_bf16_kernel(
    p_ptr: tl.tensor,  # Pointer to parameters (Always BF16 -> read/write)
    g_ptr: tl.tensor,  # Pointer to gradients (BF16 or FP32 -> read only)
    m_ptr: tl.tensor,  # Pointer to exp_avg (BF16 or FP32 -> read/write)
    v_ptr: tl.tensor,  # Pointer to exp_avg_sq (BF16 or FP32 -> read/write)
    n_elements: int,  # Total number of elements
    lr: float,  # Learning rate
    beta1: float,
    beta2: float,
    eps: float,
    weight_decay: float,
    step: int,  # Current step (for bias correction)
    seed: int,  # Random seed for stochastic rounding
    BLOCK_SIZE: tl.constexpr,
    GRAD_IS_BF16: tl.constexpr,  # noqa: N803
    STATE_IS_BF16: tl.constexpr  # noqa: N803
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # load parameters
    p_bf16 = tl.load(p_ptr + offsets, mask=mask)
    p_fp32 = p_bf16.to(tl.float32)

    # load grad
    if GRAD_IS_BF16:
        g_fp32 = tl.load(g_ptr + offsets, mask=mask).to(tl.float32)
    else:
        g_fp32 = tl.load(g_ptr + offsets, mask=mask)

    # load states
    if STATE_IS_BF16:
        m_curr = tl.load(m_ptr + offsets, mask=mask).to(tl.float32)
        v_curr = tl.load(v_ptr + offsets, mask=mask).to(tl.float32)
    else:
        m_curr = tl.load(m_ptr + offsets, mask=mask)
        v_curr = tl.load(v_ptr + offsets, mask=mask)

    # now the math goes in fp32

    # do weight decay
    p_fp32 = p_fp32 * (1.0 - lr * weight_decay)

    # update moments
    m_next = beta1 * m_curr + (1.0 - beta1) * g_fp32
    v_next = beta2 * v_curr + (1.0 - beta2) * (g_fp32 * g_fp32)

    # bias correction
    bias_correction1 = 1.0 - tl.exp(step * tl.log(beta1))
    bias_correction2 = 1.0 - tl.exp(step * tl.log(beta2))

    m_hat = m_next / bias_correction1
    v_hat = v_next / bias_correction2

    # compute update
    update = (lr * m_hat) / (tl.sqrt(v_hat) + eps)

    p_new_fp32 = p_fp32 - update

    # and now we store...
    # p -> always stochastic fp32 -> bf16
    # states -> depending on constexprs
    p_new_bf16 = fp32_to_bf16_kernel(p_new_fp32, offsets, seed)
    tl.store(p_ptr + offsets, p_new_bf16, mask=mask)

    if STATE_IS_BF16:
        m_next_bf16 = fp32_to_bf16_kernel(m_next, offsets, seed + 42)
        v_next_bf16 = fp32_to_bf16_kernel(v_next, offsets, seed + 67)

        tl.store(m_ptr + offsets, m_next_bf16, mask=mask)
        tl.store(v_ptr + offsets, v_next_bf16, mask=mask)
    else:
        tl.store(m_ptr + offsets, m_next, mask=mask)
        tl.store(v_ptr + offsets, v_next, mask=mask)


def adamw_stochastic_bf16_(  # noqa: C901
    params: torch.Tensor,
    grads: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    lr: float,
    beta1: float,
    beta2: float,
    eps: float,
    weight_decay: float,
    step: int,
    generator: torch.Generator | None = None
) -> None:
    """
    Performs a single in-place AdamW optimization step.

    It is specifically designed for scenarios where parameters are stored in BFloat16.

    To mitigate precision loss during the parameter update, it utilizes stochastic rounding when casting
    FP32 calculation results back to BFloat16.

    This function supports mixed precision for gradients and optimizer states (they can be
    either FP32 or BFloat16).

    Args:
        params: The tensor of model parameters to update. Must be BFloat16 and contiguous.
        grads: The gradient tensor.
        exp_avg: The exponential moving average of gradient values (first moment).
        exp_avg_sq: The exponential moving average of squared gradient values (second moment).
        lr: The learning rate.
        beta1: Decay rate for the first moment estimate.
        beta2: Decay rate for the second moment estimate.
        eps: Term added to the denominator to improve numerical stability.
        weight_decay: Weight decay coefficient.
        step: The current optimization step count, used for bias correction.
        generator: PyTorch random number generator used to create the seed for stochastic rounding.

    Raises:
        ValueError: If main parameters are not BFloat16, if input tensor shapes do not match,
            if input tensors are not contiguous (for those that require in-place modification),
            if the optimizer states (exp_avg, exp_avg_sq) have different dtypes.
    """

    # check shape equality
    if grads.shape != params.shape:
        raise ValueError("Shape mismatch between grads and params.")

    if exp_avg.shape != params.shape:
        raise ValueError("Shape mismatch between exp_avg state and params.")

    if exp_avg_sq.shape != params.shape:
        raise ValueError("Shape mismatch between exp_avg_sq state and params.")

    # check params
    if params.dtype != torch.bfloat16:
        raise ValueError("Params must be BFloat16 for this kernel.")

    if not params.is_contiguous():
        raise ValueError("Params must be contiguous since it is an in-place kernel.")

    # check grads
    if not grads.is_contiguous():
        grads = grads.contiguous()

    # check states
    if not exp_avg.is_contiguous():
        raise ValueError("Exp_avg state must be contiguous since it is an in-place kernel.")

    if not exp_avg_sq.is_contiguous():
        raise ValueError("Exp_avg_sq state must be contiguous since it is an in-place kernel.")

    if exp_avg.dtype != exp_avg_sq.dtype:
        raise ValueError("States have different dtypes.")

    n_elements = params.numel()

    grad_is_bf16 = (grads.dtype == torch.bfloat16)
    state_is_bf16 = (exp_avg.dtype == torch.bfloat16)

    # Generate random seed
    seed = torch.randint(
        0, 2 ** 31 - 1, (1,),
        device="cpu",
        generator=generator
    ).item()

    def _grid(meta: dict[str, int]) -> tuple[int, ...]:
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _adamw_stochastic_bf16_kernel[_grid](
        params,
        grads,
        exp_avg,
        exp_avg_sq,

        n_elements,

        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        step,
        seed,

        GRAD_IS_BF16=grad_is_bf16,
        STATE_IS_BF16=state_is_bf16
    )
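
As a quick illustration of how this entry point is meant to be called (not part of the package: the tensor shapes, hyperparameters, and direct module import below are illustrative, and a CUDA device is assumed since the kernel is compiled with Triton):

```python
import torch

from d9d.kernel.stochastic.adamw_step import adamw_stochastic_bf16_

# Illustrative setup: BF16 master weights, FP32 gradients, BF16 optimizer state.
params = torch.randn(4096, device="cuda").bfloat16()
grads = torch.randn(4096, device="cuda")     # FP32 gradients are allowed
exp_avg = torch.zeros_like(params)           # same dtype for both state tensors
exp_avg_sq = torch.zeros_like(params)

generator = torch.Generator(device="cpu").manual_seed(0)

adamw_stochastic_bf16_(
    params, grads, exp_avg, exp_avg_sq,
    lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01,
    step=1,
    generator=generator,
)
```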

d9d/kernel/stochastic/copy.py
ADDED
@@ -0,0 +1,104 @@
import torch
import triton
import triton.language as tl

from .ops import fp32_to_bf16_kernel


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
    ],
    key=["n_elements"]
)
@triton.jit
def _copy_fp32_to_bf16_kernel(
    source_ptr: torch.Tensor,
    target_ptr: torch.Tensor,
    n_elements: int,
    seed: int,
    BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # load source value (fp32)
    val_fp32 = tl.load(source_ptr + offsets, mask=mask)

    val_bf16 = fp32_to_bf16_kernel(
        val_fp32=val_fp32,
        offsets=offsets,
        seed=seed
    )

    tl.store(target_ptr + offsets, val_bf16, mask=mask)


def copy_fp32_to_bf16_stochastic_(
    target: torch.Tensor,
    source: torch.Tensor,
    generator: torch.Generator | None = None
) -> torch.Tensor:
    """
    Copies elements from a Float32 tensor to a BFloat16 tensor using stochastic rounding.

    Unlike standard round-to-nearest casting, stochastic rounding probabilistically rounds
    numbers up or down based on the value of the bits being truncated. This preserves the
    expected value of the tensor (E[round(x)] = x), which is crucial for accumulating
    gradients or parameters in low precision without stagnation.

    This operation is performed in-place on the target tensor.

    Args:
        target: The output tensor where results are written. Must be of type BFloat16
            and contiguous.
        source: The input tensor containing values to copy. Must be of type Float32.
        generator: An optional PyTorch RNG generator to strictly control the random
            noise used for rounding.

    Returns:
        The target tensor, modified in-place.

    Raises:
        ValueError: If target is not contiguous, if source/target shapes do not match,
            or if dtypes are not FP32 and BF16 respectively.
    """

    if not source.is_contiguous():
        source = source.contiguous()

    if not target.is_contiguous():
        raise ValueError("Since this is an in-place operation, target should be a contiguous tensor!")

    if source.shape != target.shape:
        raise ValueError("Source and Target Tensors are of different shapes")

    if source.dtype != torch.float32:
        raise ValueError("Source must be Float32")
    if target.dtype != torch.bfloat16:
        raise ValueError("Target must be BFloat16")

    n_elements = source.numel()

    # Generate a random seed for this specific kernel launch
    seed = torch.randint(
        0, 2 ** 31 - 1, (1,),
        device="cpu",
        generator=generator
    ).item()

    def _grid(meta: dict[str, int]) -> tuple[int, ...]:
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _copy_fp32_to_bf16_kernel[_grid](
        source,
        target,
        n_elements,
        seed
    )
    return target
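
A hedged sketch of the property the docstring describes: repeated stochastic casts of a value that BF16 cannot represent exactly average back toward the FP32 source, instead of collapsing onto a single grid point as round-to-nearest would (illustrative only; the direct module import and CUDA device are assumptions):

```python
import torch

from d9d.kernel.stochastic.copy import copy_fp32_to_bf16_stochastic_

source = torch.full((1024,), 1.0 + 2 ** -10, device="cuda")  # not representable in BF16
target = torch.empty_like(source, dtype=torch.bfloat16)

# Averaging many stochastic casts should recover roughly the FP32 value,
# while round-to-nearest would always return exactly 1.0 here.
acc = torch.zeros_like(source)
for _ in range(100):
    copy_fp32_to_bf16_stochastic_(target, source)
    acc += target.float()
print((acc / 100).mean().item())  # ≈ 1.00098 rather than 1.0
```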

d9d/kernel/stochastic/ops/round.py
ADDED
@@ -0,0 +1,22 @@
import triton
import triton.language as tl


@triton.jit
def fp32_to_bf16_kernel(
    val_fp32: tl.tensor,
    offsets: tl.tensor,
    seed: int,
) -> tl.tensor:
    val_ui32 = val_fp32.to(tl.uint32, bitcast=True)

    # create random noise for last bits
    rand_val = tl.randint(seed, offsets)
    noise = rand_val.to(tl.uint32) & 0xFFFF

    # add this noise (FP32)
    val_ui32_noisy = val_ui32 + noise

    # save in 16 bits
    bf16_bits = (val_ui32_noisy >> 16).to(tl.int16)
    return bf16_bits.to(tl.bfloat16, bitcast=True)
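
The helper above is the core stochastic-rounding trick: add 16 uniform random bits below the BF16 cut line, let any carry propagate upward, then keep only the top 16 bits of the FP32 pattern. A small CPU-side sketch of the same idea for clarity (illustrative only; it uses signed int32 views because plain PyTorch does not offer uint32 arithmetic, which is fine for the magnitudes used here):

```python
import torch

def stochastic_fp32_to_bf16_reference(x: torch.Tensor) -> torch.Tensor:
    """CPU reference for the bit trick above: add 16 random bits, then truncate."""
    bits = x.view(torch.int32)                                     # reinterpret FP32 bit pattern
    noise = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32)  # uniform in [0, 2^16)
    noisy = bits + noise                                           # carry may spill into the kept bits
    truncated = (noisy >> 16) << 16                                # zero out the discarded low 16 bits
    return truncated.view(torch.float32).bfloat16()                # exact: low mantissa bits are zero

x = torch.randn(8)
print(stochastic_fp32_to_bf16_reference(x))
```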

d9d/kernel/swiglu/function.py
ADDED
@@ -0,0 +1,36 @@
from typing import Any

import torch
from torch.autograd import Function

from .op import silu_mul_backward, silu_mul_forward


class SiLUMulFunction(Function):
    """
    Autograd function for the fused silu(x)*y operation.
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x, y)
        return silu_mul_forward(x, y)

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x, y = ctx.saved_tensors
        return silu_mul_backward(grad_output, x, y)


def silu_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Applies the SiLU multiplication operation: SiLU(x) * y.

    Args:
        x: Input tensor x.
        y: Input tensor y.

    Returns:
        The resulting tensor of the same shape as inputs.
    """
    return SiLUMulFunction.apply(x, y)
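
A minimal usage sketch for the autograd wrapper (illustrative; the direct module import and CUDA device are assumptions), checking the fused forward against the unfused eager formulation:

```python
import torch
import torch.nn.functional as F

from d9d.kernel.swiglu.function import silu_mul

x = torch.randn(8, 128, device="cuda", requires_grad=True)
y = torch.randn(8, 128, device="cuda", requires_grad=True)

out = silu_mul(x, y)   # fused Triton forward
out.sum().backward()   # backward also runs the fused Triton kernel

# Reference: the unfused eager equivalent
ref = F.silu(x.detach()) * y.detach()
print(torch.allclose(out.detach(), ref, atol=1e-5))  # True up to small numerical differences
```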

d9d/kernel/swiglu/op.py
ADDED
@@ -0,0 +1,167 @@
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
    ],
    key=["n_elements"]
)
@triton.jit
def _silu_mul_kernel(
    x_ptr: torch.Tensor,
    y_ptr: torch.Tensor,
    out_ptr: torch.Tensor,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    # prepare
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # read
    x = tl.load(x_ptr + offsets, mask=mask)
    x_fp32 = x.to(tl.float32)  # sigmoid wants fp32
    y = tl.load(y_ptr + offsets, mask=mask)

    # compute
    # cast back to match with torch
    silu_x = (x_fp32 * tl.sigmoid(x_fp32)).cast(y.dtype)
    out = silu_x * y

    # write
    tl.store(out_ptr + offsets, out, mask=mask)


def silu_mul_forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the forward pass of silu(x)*y using Triton.

    Args:
        x: Input tensor x.
        y: Input tensor y.

    Returns:
        The output tensor.

    Raises:
        ValueError: If inputs x and y do not match in shape or device.
    """

    if x.shape != y.shape or x.device != y.device:
        raise ValueError("Inputs x and y must have the same shape and be on the same device.")

    if not x.is_contiguous():
        x = x.contiguous()
    if not y.is_contiguous():
        y = y.contiguous()

    n_elements = x.numel()
    out = torch.empty_like(x)

    def _grid(meta: dict[str, int]) -> tuple[int, ...]:
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _silu_mul_kernel[_grid](
        x, y, out,
        n_elements
    )

    return out


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
    ],
    key=["n_elements"]
)
@triton.jit
def _silu_mul_backward_kernel(
    grad_out_ptr: torch.Tensor,
    x_ptr: torch.Tensor,
    y_ptr: torch.Tensor,
    grad_x_ptr: torch.Tensor,
    grad_y_ptr: torch.Tensor,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr
):
    # prepare
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # read
    dout = tl.load(grad_out_ptr + offsets, mask=mask)
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)  # sigmoid wants fp32
    y = tl.load(y_ptr + offsets, mask=mask)

    # Recompute SiLU components
    sig_x = tl.sigmoid(x)
    silu_x = x * sig_x

    # Compute grad_y
    # dy = dout * silu(x)
    dy = dout * silu_x
    tl.store(grad_y_ptr + offsets, dy, mask=mask)

    # Compute grad_x
    # silu'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
    #          = sigmoid(x) + silu(x) * (1 - sigmoid(x))
    d_silu = sig_x + silu_x * (1.0 - sig_x)

    # dx = dout * y * silu'(x)
    dx = dout * y * d_silu
    tl.store(grad_x_ptr + offsets, dx, mask=mask)


def silu_mul_backward(
    grad_output: torch.Tensor, x: torch.Tensor, y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the backward pass of silu(x)*y using Triton.

    Args:
        grad_output: Gradient of the loss with respect to the output.
        x: Original input tensor x.
        y: Original input tensor y.

    Returns:
        A tuple of (grad_x, grad_y).
    """

    if not grad_output.is_contiguous():
        grad_output = grad_output.contiguous()
    if not x.is_contiguous():
        x = x.contiguous()
    if not y.is_contiguous():
        y = y.contiguous()

    n_elements = x.numel()

    grad_x = torch.empty_like(x)
    grad_y = torch.empty_like(y)

    def _grid(meta: dict[str, int]) -> tuple[int, ...]:
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _silu_mul_backward_kernel[_grid](
        grad_output, x, y,
        grad_x, grad_y,
        n_elements
    )

    return grad_x, grad_y
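
The backward kernel relies on the identity silu'(x) = sigmoid(x) + silu(x) * (1 - sigmoid(x)). A quick sanity check (illustrative sketch, assuming a CUDA device and direct module import) is to compare the fused gradients against autograd on the eager expression:

```python
import torch
import torch.nn.functional as F

from d9d.kernel.swiglu.op import silu_mul_backward

x = torch.randn(1024, device="cuda", requires_grad=True)
y = torch.randn(1024, device="cuda", requires_grad=True)
grad_out = torch.randn(1024, device="cuda")

# Eager reference gradients via autograd
(F.silu(x) * y).backward(grad_out)

# Fused Triton backward on the same inputs
grad_x, grad_y = silu_mul_backward(grad_out, x.detach(), y.detach())
print(torch.allclose(grad_x, x.grad, atol=1e-4),
      torch.allclose(grad_y, y.grad, atol=1e-4))
```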

d9d/loop/__init__.py
ADDED
File without changes

d9d/loop/auto/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .auto_lr_scheduler import AutoLRSchedulerConfig, AutoLRSchedulerProvider
from .auto_optimizer import AutoOptimizerConfig, AutoOptimizerProvider

__all__ = [
    "AutoLRSchedulerConfig",
    "AutoLRSchedulerProvider",
    "AutoOptimizerConfig",
    "AutoOptimizerProvider"
]

d9d/loop/auto/auto_lr_scheduler.py
ADDED
@@ -0,0 +1,46 @@
from typing import Annotated, Literal

from pydantic import BaseModel, Field

from d9d.core.protocol import LRSchedulerProtocol
from d9d.loop.control import InitializeLRSchedulerContext, LRSchedulerProvider
from d9d.lr_scheduler.piecewise import PiecewiseSchedulerConfig, piecewise_scheduler_from_config


class PiecewiseConfig(BaseModel):
    """
    Configuration for the piecewise learning rate scheduler.

    Attributes:
        name: Discriminator tag, must be "piecewise".
        scheduler: Detailed configuration for the piecewise schedule.
    """
    name: Literal["piecewise"] = "piecewise"

    scheduler: PiecewiseSchedulerConfig


AutoLRSchedulerConfig = Annotated[
    PiecewiseConfig,
    Field(discriminator="name")
]


class AutoLRSchedulerProvider(LRSchedulerProvider):
    """
    LRSchedulerProvider that builds a learning rate scheduler based on a configuration object.
    """

    def __init__(self, config: AutoLRSchedulerConfig):
        """Constructs the AutoLRSchedulerProvider object."""

        self._config = config

    def __call__(self, context: InitializeLRSchedulerContext) -> LRSchedulerProtocol:
        match self._config:
            case PiecewiseConfig():
                return piecewise_scheduler_from_config(
                    self._config.scheduler,
                    optimizer=context.optimizer,
                    total_steps=context.total_steps
                )
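
AutoLRSchedulerConfig is a pydantic discriminated union keyed on `name`; with a single member today, the discriminator mainly leaves room for additional scheduler types. A standalone sketch of how that validation pattern behaves (CosineConfig and ConstantConfig below are hypothetical stand-ins, not part of d9d):

```python
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


# Hypothetical tagged union discriminated on "name", mirroring the pattern above.
class CosineConfig(BaseModel):
    name: Literal["cosine"] = "cosine"
    min_lr_ratio: float = 0.1


class ConstantConfig(BaseModel):
    name: Literal["constant"] = "constant"


SchedulerConfig = Annotated[CosineConfig | ConstantConfig, Field(discriminator="name")]

adapter = TypeAdapter(SchedulerConfig)
cfg = adapter.validate_python({"name": "cosine", "min_lr_ratio": 0.05})
print(type(cfg).__name__)  # CosineConfig
```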