d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/kernel/stochastic/__init__.py ADDED
@@ -0,0 +1,11 @@
+ """
+ Utilities for stochastic type casting (e.g., FP32 to BF16).
+ """
+
+ from .adamw_step import adamw_stochastic_bf16_
+ from .copy import copy_fp32_to_bf16_stochastic_
+
+ __all__ = [
+     "adamw_stochastic_bf16_",
+     "copy_fp32_to_bf16_stochastic_"
+ ]
d9d/kernel/stochastic/adamw_step.py ADDED
@@ -0,0 +1,204 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from .ops import fp32_to_bf16_kernel
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
+         triton.Config({"BLOCK_SIZE": 2048}, num_warps=4),
+         triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
+         triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
+         triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
+     ],
+     key=["n_elements"],
+     restore_value=["p_ptr", "m_ptr", "v_ptr"]
+ )
+ @triton.jit
+ def _adamw_stochastic_bf16_kernel(
+     p_ptr: tl.tensor,  # Pointer to parameters (always BF16 -> read/write)
+     g_ptr: tl.tensor,  # Pointer to gradients (BF16 or FP32 -> read only)
+     m_ptr: tl.tensor,  # Pointer to exp_avg (BF16 or FP32 -> read/write)
+     v_ptr: tl.tensor,  # Pointer to exp_avg_sq (BF16 or FP32 -> read/write)
+     n_elements: int,  # Total number of elements
+     lr: float,  # Learning rate
+     beta1: float,
+     beta2: float,
+     eps: float,
+     weight_decay: float,
+     step: int,  # Current step (for bias correction)
+     seed: int,  # Random seed for stochastic rounding
+     BLOCK_SIZE: tl.constexpr,
+     GRAD_IS_BF16: tl.constexpr,  # noqa: N803
+     STATE_IS_BF16: tl.constexpr  # noqa: N803
+ ):
+     pid = tl.program_id(axis=0)
+     offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_elements
+
+     # load parameters
+     p_bf16 = tl.load(p_ptr + offsets, mask=mask)
+     p_fp32 = p_bf16.to(tl.float32)
+
+     # load grad
+     if GRAD_IS_BF16:
+         g_fp32 = tl.load(g_ptr + offsets, mask=mask).to(tl.float32)
+     else:
+         g_fp32 = tl.load(g_ptr + offsets, mask=mask)
+
+     # load states
+     if STATE_IS_BF16:
+         m_curr = tl.load(m_ptr + offsets, mask=mask).to(tl.float32)
+         v_curr = tl.load(v_ptr + offsets, mask=mask).to(tl.float32)
+     else:
+         m_curr = tl.load(m_ptr + offsets, mask=mask)
+         v_curr = tl.load(v_ptr + offsets, mask=mask)
+
+     # now the math goes in fp32
+
+     # do weight decay (decoupled, as in AdamW)
+     p_fp32 = p_fp32 * (1.0 - lr * weight_decay)
+
+     # update moments
+     m_next = beta1 * m_curr + (1.0 - beta1) * g_fp32
+     v_next = beta2 * v_curr + (1.0 - beta2) * (g_fp32 * g_fp32)
+
+     # bias correction: 1 - beta^step, computed as 1 - exp(step * log(beta))
+     bias_correction1 = 1.0 - tl.exp(step * tl.log(beta1))
+     bias_correction2 = 1.0 - tl.exp(step * tl.log(beta2))
+
+     m_hat = m_next / bias_correction1
+     v_hat = v_next / bias_correction2
+
+     # compute update
+     update = (lr * m_hat) / (tl.sqrt(v_hat) + eps)
+
+     p_new_fp32 = p_fp32 - update
+
+     # and now we store...
+     # p -> always stochastic fp32 -> bf16
+     # states -> depending on constexprs
+     p_new_bf16 = fp32_to_bf16_kernel(p_new_fp32, offsets, seed)
+     tl.store(p_ptr + offsets, p_new_bf16, mask=mask)
+
+     if STATE_IS_BF16:
+         m_next_bf16 = fp32_to_bf16_kernel(m_next, offsets, seed + 42)
+         v_next_bf16 = fp32_to_bf16_kernel(v_next, offsets, seed + 67)
+
+         tl.store(m_ptr + offsets, m_next_bf16, mask=mask)
+         tl.store(v_ptr + offsets, v_next_bf16, mask=mask)
+     else:
+         tl.store(m_ptr + offsets, m_next, mask=mask)
+         tl.store(v_ptr + offsets, v_next, mask=mask)
+
+
+ def adamw_stochastic_bf16_(  # noqa: C901
+     params: torch.Tensor,
+     grads: torch.Tensor,
+     exp_avg: torch.Tensor,
+     exp_avg_sq: torch.Tensor,
+     lr: float,
+     beta1: float,
+     beta2: float,
+     eps: float,
+     weight_decay: float,
+     step: int,
+     generator: torch.Generator | None = None
+ ) -> None:
+     """
+     Performs a single in-place AdamW optimization step.
+
+     It is specifically designed for scenarios where parameters are stored in BFloat16.
+
+     To mitigate precision loss during the parameter update, it utilizes stochastic rounding when casting
+     FP32 calculation results back to BFloat16.
+
+     This function supports mixed precision for gradients and optimizer states (they can be
+     either FP32 or BFloat16).
+
+     Args:
+         params: The tensor of model parameters to update. Must be BFloat16 and contiguous.
+         grads: The gradient tensor.
+         exp_avg: The exponential moving average of gradient values (first moment).
+         exp_avg_sq: The exponential moving average of squared gradient values (second moment).
+         lr: The learning rate.
+         beta1: Decay rate for the first moment estimate.
+         beta2: Decay rate for the second moment estimate.
+         eps: Term added to the denominator to improve numerical stability.
+         weight_decay: Weight decay coefficient.
+         step: The current optimization step count, used for bias correction.
+         generator: PyTorch random number generator used to create the seed for stochastic rounding.
+
+     Raises:
+         ValueError: If main parameters are not BFloat16, if input tensor shapes do not match,
+             if input tensors are not contiguous (for those that require in-place modification),
+             or if the optimizer states (exp_avg, exp_avg_sq) have different dtypes.
+     """
+
+     # check shape equality
+     if grads.shape != params.shape:
+         raise ValueError("Shape mismatch between grads and params.")
+
+     if exp_avg.shape != params.shape:
+         raise ValueError("Shape mismatch between exp_avg state and params.")
+
+     if exp_avg_sq.shape != params.shape:
+         raise ValueError("Shape mismatch between exp_avg_sq state and params.")
+
+     # check params
+     if params.dtype != torch.bfloat16:
+         raise ValueError("Params must be BFloat16 for this kernel.")
+
+     if not params.is_contiguous():
+         raise ValueError("Params must be contiguous since it is an in-place kernel.")
+
+     # check grads
+     if not grads.is_contiguous():
+         grads = grads.contiguous()
+
+     # check states
+     if not exp_avg.is_contiguous():
+         raise ValueError("Exp_avg state must be contiguous since it is an in-place kernel.")
+
+     if not exp_avg_sq.is_contiguous():
+         raise ValueError("Exp_avg_sq state must be contiguous since it is an in-place kernel.")
+
+     if exp_avg.dtype != exp_avg_sq.dtype:
+         raise ValueError("States have different dtypes.")
+
+     n_elements = params.numel()
+
+     grad_is_bf16 = (grads.dtype == torch.bfloat16)
+     state_is_bf16 = (exp_avg.dtype == torch.bfloat16)
+
+     # Generate random seed
+     seed = torch.randint(
+         0, 2 ** 31 - 1, (1,),
+         device="cpu",
+         generator=generator
+     ).item()
+
+     def _grid(meta: dict[str, int]) -> tuple[int, ...]:
+         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+     _adamw_stochastic_bf16_kernel[_grid](
+         params,
+         grads,
+         exp_avg,
+         exp_avg_sq,
+
+         n_elements,
+
+         lr,
+         beta1,
+         beta2,
+         eps,
+         weight_decay,
+         step,
+         seed,
+
+         GRAD_IS_BF16=grad_is_bf16,
+         STATE_IS_BF16=state_is_bf16
+     )
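Note: the snippet below is not part of the package diff. It is a minimal usage sketch of the entry point above, assuming a CUDA device with Triton available; the tensor names are hypothetical.

import torch

from d9d.kernel.stochastic import adamw_stochastic_bf16_

# BF16 parameters with FP32 gradients and FP32 optimizer states (one of the supported mixes).
param = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
grad = torch.randn(4096, device="cuda", dtype=torch.float32)
exp_avg = torch.zeros(4096, device="cuda", dtype=torch.float32)
exp_avg_sq = torch.zeros(4096, device="cuda", dtype=torch.float32)

# One in-place AdamW step; the generator only seeds the stochastic rounding.
adamw_stochastic_bf16_(
    params=param, grads=grad, exp_avg=exp_avg, exp_avg_sq=exp_avg_sq,
    lr=1e-3, beta1=0.9, beta2=0.95, eps=1e-8, weight_decay=0.1,
    step=1, generator=torch.Generator(device="cpu").manual_seed(0),
)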
d9d/kernel/stochastic/copy.py ADDED
@@ -0,0 +1,104 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from .ops import fp32_to_bf16_kernel
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
+         triton.Config({"BLOCK_SIZE": 2048}, num_warps=4),
+         triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
+         triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
+         triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
+     ],
+     key=["n_elements"]
+ )
+ @triton.jit
+ def _copy_fp32_to_bf16_kernel(
+     source_ptr: torch.Tensor,
+     target_ptr: torch.Tensor,
+     n_elements: int,
+     seed: int,
+     BLOCK_SIZE: tl.constexpr
+ ):
+     pid = tl.program_id(axis=0)
+     offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_elements
+
+     # load source value (fp32)
+     val_fp32 = tl.load(source_ptr + offsets, mask=mask)
+
+     val_bf16 = fp32_to_bf16_kernel(
+         val_fp32=val_fp32,
+         offsets=offsets,
+         seed=seed
+     )
+
+     tl.store(target_ptr + offsets, val_bf16, mask=mask)
+
+
+ def copy_fp32_to_bf16_stochastic_(
+     target: torch.Tensor,
+     source: torch.Tensor,
+     generator: torch.Generator | None = None
+ ) -> torch.Tensor:
+     """
+     Copies elements from a Float32 tensor to a BFloat16 tensor using stochastic rounding.
+
+     Unlike standard round-to-nearest casting, stochastic rounding probabilistically rounds
+     numbers up or down based on the value of the bits being truncated. This preserves the
+     expected value of the tensor (E[round(x)] = x), which is crucial for accumulating
+     gradients or parameters in low precision without stagnation.
+
+     This operation is performed in-place on the target tensor.
+
+     Args:
+         target: The output tensor where results are written. Must be of type BFloat16
+             and contiguous.
+         source: The input tensor containing values to copy. Must be of type Float32.
+         generator: An optional PyTorch RNG generator to strictly control the random
+             noise used for rounding.
+
+     Returns:
+         The target tensor, modified in-place.
+
+     Raises:
+         ValueError: If target is not contiguous, if source/target shapes do not match,
+             or if dtypes are not FP32 and BF16 respectively.
+     """
+
+     if not source.is_contiguous():
+         source = source.contiguous()
+
+     if not target.is_contiguous():
+         raise ValueError("Since this is an in-place operation, target should be a contiguous tensor!")
+
+     if source.shape != target.shape:
+         raise ValueError("Source and Target Tensors are of different shapes")
+
+     if source.dtype != torch.float32:
+         raise ValueError("Source must be Float32")
+     if target.dtype != torch.bfloat16:
+         raise ValueError("Target must be BFloat16")
+
+     n_elements = source.numel()
+
+     # Generate a random seed for this specific kernel launch
+     seed = torch.randint(
+         0, 2 ** 31 - 1, (1,),
+         device="cpu",
+         generator=generator
+     ).item()
+
+     def _grid(meta: dict[str, int]) -> tuple[int, ...]:
+         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+     _copy_fp32_to_bf16_kernel[_grid](
+         source,
+         target,
+         n_elements,
+         seed
+     )
+     return target
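Not part of the diff: a rough sketch (assuming a CUDA device) of the expected-value property described in the docstring. Averaging many independent stochastic casts should track the FP32 source far more closely than any single cast.

import torch

from d9d.kernel.stochastic import copy_fp32_to_bf16_stochastic_

source = torch.rand(10_000, device="cuda", dtype=torch.float32)
target = torch.empty(10_000, device="cuda", dtype=torch.bfloat16)

acc = torch.zeros_like(source)
for _ in range(256):
    copy_fp32_to_bf16_stochastic_(target, source)
    acc += target.float()

mean_error = (acc / 256 - source).abs().mean()
single_error = (source.to(torch.bfloat16).float() - source).abs().mean()
# mean_error is expected to be substantially smaller than single_error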
d9d/kernel/stochastic/ops/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .round import fp32_to_bf16_kernel
+
+ __all__ = [
+     "fp32_to_bf16_kernel"
+ ]
d9d/kernel/stochastic/ops/round.py ADDED
@@ -0,0 +1,22 @@
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def fp32_to_bf16_kernel(
+     val_fp32: tl.tensor,
+     offsets: tl.tensor,
+     seed: int,
+ ) -> tl.tensor:
+     val_ui32 = val_fp32.to(tl.uint32, bitcast=True)
+
+     # random noise covering the 16 bits that will be truncated
+     rand_val = tl.randint(seed, offsets)
+     noise = rand_val.to(tl.uint32) & 0xFFFF
+
+     # add the noise to the FP32 bit pattern
+     val_ui32_noisy = val_ui32 + noise
+
+     # keep the upper 16 bits (the BF16 representation)
+     bf16_bits = (val_ui32_noisy >> 16).to(tl.int16)
+     return bf16_bits.to(tl.bfloat16, bitcast=True)
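Not part of the diff: the same trick expressed with plain PyTorch integer ops, which may help when reasoning about the kernel above. Adding uniform 16-bit noise to the FP32 bit pattern before truncating the low 16 bits rounds up with probability proportional to the discarded mantissa bits. The function name below is illustrative only.

import torch

def stochastic_fp32_to_bf16_reference(x: torch.Tensor) -> torch.Tensor:
    # Reinterpret the FP32 values as 32-bit integers.
    bits = x.view(torch.int32)
    # Uniform noise covering the 16 bits that will be truncated.
    noise = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32, device=x.device)
    noisy = bits + noise
    # Keep the upper 16 bits and reinterpret them as BF16.
    return (noisy >> 16).to(torch.int16).view(torch.bfloat16)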
d9d/kernel/swiglu/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .function import silu_mul
+
+ __all__ = [
+     "silu_mul"
+ ]
d9d/kernel/swiglu/function.py ADDED
@@ -0,0 +1,36 @@
+ from typing import Any
+
+ import torch
+ from torch.autograd import Function
+
+ from .op import silu_mul_backward, silu_mul_forward
+
+
+ class SiLUMulFunction(Function):
+     """
+     Autograd function for the fused silu(x)*y operation.
+     """
+
+     @staticmethod
+     def forward(ctx: Any, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         ctx.save_for_backward(x, y)
+         return silu_mul_forward(x, y)
+
+     @staticmethod
+     def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         x, y = ctx.saved_tensors
+         return silu_mul_backward(grad_output, x, y)
+
+
+ def silu_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+     """
+     Applies the SiLU multiplication operation: SiLU(x) * y.
+
+     Args:
+         x: Input tensor x.
+         y: Input tensor y.
+
+     Returns:
+         The resulting tensor of the same shape as the inputs.
+     """
+     return SiLUMulFunction.apply(x, y)
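Not part of the diff: a quick equivalence check of the fused op against unfused PyTorch (assuming a CUDA device; the tolerances are illustrative for BF16 inputs).

import torch
import torch.nn.functional as F

from d9d.kernel.swiglu import silu_mul

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
y = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

out_fused = silu_mul(x, y)
out_ref = F.silu(x) * y

torch.testing.assert_close(out_fused, out_ref, rtol=1e-2, atol=1e-2)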
d9d/kernel/swiglu/op.py ADDED
@@ -0,0 +1,167 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
+         triton.Config({"BLOCK_SIZE": 2048}, num_warps=4),
+         triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
+         triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
+         triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
+     ],
+     key=["n_elements"]
+ )
+ @triton.jit
+ def _silu_mul_kernel(
+     x_ptr: torch.Tensor,
+     y_ptr: torch.Tensor,
+     out_ptr: torch.Tensor,
+     n_elements: int,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     # prepare
+     pid = tl.program_id(axis=0)
+     block_start = pid * BLOCK_SIZE
+     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_elements
+
+     # read
+     x = tl.load(x_ptr + offsets, mask=mask)
+     x_fp32 = x.to(tl.float32)  # sigmoid wants fp32
+     y = tl.load(y_ptr + offsets, mask=mask)
+
+     # compute
+     # cast back to match with torch
+     silu_x = (x_fp32 * tl.sigmoid(x_fp32)).cast(y.dtype)
+     out = silu_x * y
+
+     # write
+     tl.store(out_ptr + offsets, out, mask=mask)
+
+
+ def silu_mul_forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+     """
+     Computes the forward pass of silu(x)*y using Triton.
+
+     Args:
+         x: Input tensor x.
+         y: Input tensor y.
+
+     Returns:
+         The output tensor.
+
+     Raises:
+         ValueError: If inputs x and y do not match in shape or device.
+     """
+
+     if x.shape != y.shape or x.device != y.device:
+         raise ValueError("Inputs x and y must have the same shape and be on the same device.")
+
+     if not x.is_contiguous():
+         x = x.contiguous()
+     if not y.is_contiguous():
+         y = y.contiguous()
+
+     n_elements = x.numel()
+     out = torch.empty_like(x)
+
+     def _grid(meta: dict[str, int]) -> tuple[int, ...]:
+         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+     _silu_mul_kernel[_grid](
+         x, y, out,
+         n_elements
+     )
+
+     return out
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
+         triton.Config({"BLOCK_SIZE": 2048}, num_warps=4),
+         triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
+         triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
+         triton.Config({"BLOCK_SIZE": 8192}, num_warps=8),
+     ],
+     key=["n_elements"]
+ )
+ @triton.jit
+ def _silu_mul_backward_kernel(
+     grad_out_ptr: torch.Tensor,
+     x_ptr: torch.Tensor,
+     y_ptr: torch.Tensor,
+     grad_x_ptr: torch.Tensor,
+     grad_y_ptr: torch.Tensor,
+     n_elements: int,
+     BLOCK_SIZE: tl.constexpr
+ ):
+     # prepare
+     pid = tl.program_id(0)
+     block_start = pid * BLOCK_SIZE
+     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_elements
+
+     # read
+     dout = tl.load(grad_out_ptr + offsets, mask=mask)
+     x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)  # sigmoid wants fp32
+     y = tl.load(y_ptr + offsets, mask=mask)
+
+     # Recompute SiLU components
+     sig_x = tl.sigmoid(x)
+     silu_x = x * sig_x
+
+     # Compute grad_y
+     # dy = dout * silu(x)
+     dy = dout * silu_x
+     tl.store(grad_y_ptr + offsets, dy, mask=mask)
+
+     # Compute grad_x
+     # silu'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
+     #          = sigmoid(x) + silu(x) * (1 - sigmoid(x))
+     d_silu = sig_x + silu_x * (1.0 - sig_x)
+
+     # dx = dout * y * silu'(x)
+     dx = dout * y * d_silu
+     tl.store(grad_x_ptr + offsets, dx, mask=mask)
+
+
+ def silu_mul_backward(
+     grad_output: torch.Tensor, x: torch.Tensor, y: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Computes the backward pass of silu(x)*y using Triton.
+
+     Args:
+         grad_output: Gradient of the loss with respect to the output.
+         x: Original input tensor x.
+         y: Original input tensor y.
+
+     Returns:
+         A tuple of (grad_x, grad_y).
+     """
+
+     if not grad_output.is_contiguous():
+         grad_output = grad_output.contiguous()
+     if not x.is_contiguous():
+         x = x.contiguous()
+     if not y.is_contiguous():
+         y = y.contiguous()
+
+     n_elements = x.numel()
+
+     grad_x = torch.empty_like(x)
+     grad_y = torch.empty_like(y)
+
+     def _grid(meta: dict[str, int]) -> tuple[int, ...]:
+         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+     _silu_mul_backward_kernel[_grid](
+         grad_output, x, y,
+         grad_x, grad_y,
+         n_elements
+     )
+
+     return grad_x, grad_y
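Not part of the diff: the backward formulas can be sanity-checked against autograd on the unfused expression (FP32 inputs, assuming a CUDA device).

import torch
import torch.nn.functional as F

from d9d.kernel.swiglu.op import silu_mul_backward

x = torch.randn(4, 512, device="cuda", requires_grad=True)
y = torch.randn(4, 512, device="cuda", requires_grad=True)
grad_out = torch.randn(4, 512, device="cuda")

# Reference gradients from autograd on silu(x) * y.
(F.silu(x) * y).backward(grad_out)

grad_x, grad_y = silu_mul_backward(grad_out, x.detach(), y.detach())
torch.testing.assert_close(grad_x, x.grad, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(grad_y, y.grad, rtol=1e-4, atol=1e-4)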
d9d/loop/__init__.py ADDED
File without changes
d9d/loop/auto/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .auto_lr_scheduler import AutoLRSchedulerConfig, AutoLRSchedulerProvider
+ from .auto_optimizer import AutoOptimizerConfig, AutoOptimizerProvider
+
+ __all__ = [
+     "AutoLRSchedulerConfig",
+     "AutoLRSchedulerProvider",
+     "AutoOptimizerConfig",
+     "AutoOptimizerProvider"
+ ]
d9d/loop/auto/auto_lr_scheduler.py ADDED
@@ -0,0 +1,46 @@
+ from typing import Annotated, Literal
+
+ from pydantic import BaseModel, Field
+
+ from d9d.core.protocol import LRSchedulerProtocol
+ from d9d.loop.control import InitializeLRSchedulerContext, LRSchedulerProvider
+ from d9d.lr_scheduler.piecewise import PiecewiseSchedulerConfig, piecewise_scheduler_from_config
+
+
+ class PiecewiseConfig(BaseModel):
+     """
+     Configuration for the piecewise learning rate scheduler.
+
+     Attributes:
+         name: Discriminator tag, must be "piecewise".
+         scheduler: Detailed configuration for the piecewise schedule.
+     """
+     name: Literal["piecewise"] = "piecewise"
+
+     scheduler: PiecewiseSchedulerConfig
+
+
+ AutoLRSchedulerConfig = Annotated[
+     PiecewiseConfig,
+     Field(discriminator="name")
+ ]
+
+
+ class AutoLRSchedulerProvider(LRSchedulerProvider):
+     """
+     LRSchedulerProvider that builds a learning rate scheduler based on a configuration object.
+     """
+
+     def __init__(self, config: AutoLRSchedulerConfig):
+         """Constructs the AutoLRSchedulerProvider object."""
+
+         self._config = config
+
+     def __call__(self, context: InitializeLRSchedulerContext) -> LRSchedulerProtocol:
+         match self._config:
+             case PiecewiseConfig():
+                 return piecewise_scheduler_from_config(
+                     self._config.scheduler,
+                     optimizer=context.optimizer,
+                     total_steps=context.total_steps
+                 )
+ )