d9d 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238) hide show
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,112 @@
1
+ import tarfile
2
+ import time
3
+ from contextlib import contextmanager
4
+ from pathlib import Path
5
+
6
+ import torch.profiler as tprof
7
+
8
+ from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
9
+
10
+
11
class Profiler:
    """
    Manages distributed performance profiling using PyTorch Profiler.

    This class wraps `torch.profiler` to provide automatic trace exporting,
    compression, and file naming consistent with the distributed DeviceMesh
    topology. It configures the schedule to repeat periodically based on
    the provided step counts.
    """

    def __init__(
        self,
        save_dir: Path,
        period_steps: int,
        warmup_steps: int,
        active_steps: int,
        dist_context: DistributedContext
    ):
        """
        Constructs a Profiler object.

        Args:
            save_dir: Directory where trace files will be saved.
            period_steps: Total length of a profiling cycle (wait + warmup + active).
            warmup_steps: Number of steps to ignore before recording to allow for warming-up.
            active_steps: Number of steps to actively record traces.
            dist_context: The distributed context object.

        Raises:
            ValueError: If the step counts are negative or the warmup and active
                phases do not fit into a single profiling period.
        """

        # Fail fast here instead of letting `torch.profiler.schedule` reject a
        # negative `wait` (= period - warmup - active) later, inside `open()`,
        # with a much less actionable error.
        if warmup_steps < 0:
            raise ValueError("warmup_steps must be non-negative")
        if active_steps <= 0:
            raise ValueError("active_steps must be positive")
        if period_steps < warmup_steps + active_steps:
            raise ValueError(
                "period_steps must be at least warmup_steps + active_steps"
            )

        self._save_dir = save_dir
        self._period = period_steps
        self._warmup = warmup_steps
        self._active = active_steps
        self._dist_context = dist_context

    def _get_save_file_name(self) -> str:
        """Builds a trace file name that is unique per rank in distributed runs."""
        if self._dist_context.mesh_params.is_distributed:
            mesh_regular = self._dist_context.mesh_for(REGULAR_DOMAIN)
            coord = mesh_regular.get_coordinate()
            if coord is None:
                # get_coordinate() returns None when this rank is not part of
                # the mesh — that should never happen for a profiling rank.
                raise RuntimeError("Invalid mesh")
            coord_str = "-".join(map(str, coord))
            rank = mesh_regular.get_rank()
            return f"rank-{rank}-coord-{coord_str}-trace.json"
        else:
            return "trace.json"

    def _dump_trace(self, prof: tprof.profile):
        """Exports, compresses and cleans up a finished trace (on_trace_ready hook)."""
        save_dir = self._save_dir / f"step_{prof.step_num}"
        save_dir.mkdir(parents=True, exist_ok=True)
        save_file = save_dir / self._get_save_file_name()

        begin = time.monotonic()

        prof.export_chrome_trace(str(save_file))
        # Chrome traces are large JSON files; compress then delete the raw dump.
        with tarfile.open(save_file.with_suffix(".tar.gz"), "w:gz") as tar:
            tar.add(save_file, arcname=save_file.name)
        save_file.unlink()

        end = time.monotonic()

        self._dist_context.logger.info(
            f"Finished dumping profiler traces in {end - begin:.2f} seconds"
        )

    @contextmanager
    def open(self, start_step: int):
        """
        Opens a context manager for profiling execution.

        This sets up the `torch.profiler.profile` with a schedule derived from
        the initialization parameters. It captures both CPU and CUDA activities,
        records shapes, and tracks stack traces.

        When the schedule triggers `on_trace_ready`, the trace is automatically
        exported to the `save_dir`, compressed into a `.tar.gz` file, and the
        raw JSON is removed to save space.

        Args:
            start_step: The current global step number to initialize the
                profiler state.

        Yields:
            The configured torch profiler instance.
        """

        # The profiling cycle is: wait -> warmup -> active, repeating every
        # `period` steps; validated non-negative in __init__.
        wait = self._period - (self._active + self._warmup)
        warmup = self._warmup
        active = self._active

        with tprof.profile(
            activities=[
                tprof.ProfilerActivity.CPU,
                tprof.ProfilerActivity.CUDA
            ],
            schedule=tprof.schedule(wait=wait, warmup=warmup, active=active),
            on_trace_ready=self._dump_trace,
            record_shapes=True,
            with_stack=True
        ) as profiler:
            # Align the profiler's internal step counter with the training
            # loop so resumed jobs keep a consistent cycle phase.
            profiler.step_num = start_step
            yield profiler
@@ -0,0 +1,6 @@
1
+ from .main_process import load_state_dict_main_process, state_dict_main_process
2
+
3
+ __all__ = [
4
+ "load_state_dict_main_process",
5
+ "state_dict_main_process"
6
+ ]
@@ -0,0 +1,44 @@
1
+ from typing import Any
2
+
3
+ from torch.distributed.checkpoint.stateful import Stateful
4
+
5
+ from d9d.core.dist_context import DistributedContext
6
+
7
+
8
def state_dict_main_process(dist_context: DistributedContext, obj: Stateful) -> dict[str, Any]:
    """
    Serialises an object's state on the main rank only.

    Components whose global state lives on the driver/main rank should be
    checkpointed exactly once; every other rank therefore contributes an
    empty mapping, avoiding duplication or synchronization issues.

    Args:
        dist_context: Distributed context used to determine main-process status.
        obj: The stateful object to serialize.

    Returns:
        A mapping with the object's state under the ``"main_process"`` key on
        the main rank, and an empty mapping on all other ranks.
    """

    if not dist_context.is_main_process:
        return {}

    return {"main_process": obj.state_dict()}
31
+
32
+
33
def load_state_dict_main_process(dist_context: DistributedContext, obj: Stateful, state_dict: dict[str, Any]):
    """
    Restores an object's state on the main rank only.

    Counterpart of ``state_dict_main_process``; non-main ranks are a no-op.

    Args:
        dist_context: Distributed context used to determine main-process status.
        obj: The stateful object to restore.
        state_dict: The state dictionary created by "state_dict_main_process" function.
    """

    if not dist_context.is_main_process:
        return

    obj.load_state_dict(state_dict["main_process"])
d9d/kernel/__init__.py ADDED
File without changes
@@ -0,0 +1,5 @@
1
+ from .main import linear_cross_entropy
2
+
3
+ __all__ = [
4
+ "linear_cross_entropy"
5
+ ]
d9d/kernel/cce/cce.py ADDED
@@ -0,0 +1,298 @@
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ from dataclasses import dataclass
3
+ from typing import cast
4
+
5
+ import torch
6
+ import torch.amp
7
+
8
+ from cut_cross_entropy.cce_backward import cce_backward_kernel
9
+ from cut_cross_entropy.cce_lse_forward import cce_lse_forward_kernel
10
+ from cut_cross_entropy.constants import IGNORE_INDEX
11
+ from cut_cross_entropy.doc import CCE_OPTS_DOC, LINEAR_CROSS_ENTROPY_DOC, add_doc_start
12
+ from cut_cross_entropy.utils import (
13
+ TensorInfo,
14
+ _build_flat_valids,
15
+ _handle_eps,
16
+ handle_reduction_none,
17
+ )
18
+ from cut_cross_entropy.vocab_parallel.utils import (
19
+ VocabParallelOptions,
20
+ vp_reduce_correct_logit,
21
+ vp_reduce_lse,
22
+ )
23
+
24
+
25
@dataclass
class CCEParams:
    """Non-differentiable arguments bundled into a single object so they can be
    threaded through ``LinearCrossEntropyFunction.apply`` as one parameter."""

    # Flattened target class indices, one per token position (global ids, even
    # under vocab parallelism — remapping to shard-local ids happens later).
    targets: torch.Tensor
    # Flat indices of positions to include in the loss, or None for all positions.
    valids: torch.Tensor | None
    # Optional logit soft-capping value; None disables it.
    softcap: float | None
    # Loss reduction: "mean", "sum" or "none".
    reduction: str
    # Gradient filtering threshold; None disables gradient filtering entirely.
    filter_eps: float | None
    # Number of positions targets are shifted relative to embeddings (0 = none).
    shift: int
    # Original (unflattened) shape of `targets`, used to restore "none" reductions.
    batch_shape: torch.Size
    # Accumulate the embedding gradient in fp32.
    accum_e_fp32: bool
    # Accumulate the classifier gradient in fp32.
    accum_c_fp32: bool
    # Apply gradient filtering to the embedding gradient.
    filter_e_grad: bool
    # Apply gradient filtering to the classifier gradient.
    filter_c_grad: bool
    # Vocab-parallel sharding options; None when the vocabulary is not sharded.
    vocab_parallel_options: VocabParallelOptions | None
    # Also return the per-token log-sum-exp values alongside the loss.
    return_lse: bool
+
41
+
42
+ @torch.compile(fullgraph=True)
43
+ def sort_logit_avg(logit_avg: torch.Tensor) -> torch.Tensor:
44
+ return torch.argsort(logit_avg).to(torch.int32)
45
+
46
+
47
class LinearCrossEntropyFunction(torch.autograd.Function):
    """Autograd function computing linear cross-entropy via the fused
    cut-cross-entropy (CCE) kernels, i.e. directly from embeddings ``e`` and
    classifier matrix ``c`` without materialising the full logit matrix.

    Supports vocab parallelism: each rank handles the slice of the vocabulary
    described by ``params.vocab_parallel_options`` and partial LSE / correct
    logit values are reduced across the group.
    """

    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(
        ctx,
        e: torch.Tensor,
        c: torch.Tensor,
        bias: torch.Tensor | None,
        params: CCEParams,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Computes the (reduced) NLL loss and, optionally, the per-token LSE.

        Returns:
            Tuple ``(loss, lse)`` where ``lse`` is None unless ``params.return_lse``.
        """
        needs_grad = e.requires_grad or c.requires_grad
        if bias is not None:
            needs_grad = needs_grad or bias.requires_grad

        # The per-class logit average is only needed by backward to order the
        # vocabulary for gradient filtering; skip computing it otherwise.
        return_logit_avg = (
            needs_grad
            and params.filter_eps is not None
            and (params.filter_c_grad or params.filter_e_grad)
        )

        # Capture original dtype / requires_grad BEFORE any autocast
        # conversion below — backward needs the pre-cast metadata.
        e_info = TensorInfo(e.dtype, e.requires_grad)
        c_info = TensorInfo(c.dtype, c.requires_grad)

        bias_info = None
        if bias is not None:
            bias_info = TensorInfo(bias.dtype, bias.requires_grad)

        if torch.is_autocast_enabled():
            e = e.to(dtype=torch.get_autocast_gpu_dtype())
            c = c.to(dtype=torch.get_autocast_gpu_dtype())

            if bias is not None:
                bias = bias.to(dtype=torch.get_autocast_gpu_dtype())

        targets = params.targets
        if (vp_opts := params.vocab_parallel_options) is not None:
            # Remap global target ids into this shard's local [0, vocab_shard) range.
            is_my_target = (targets >= vp_opts.start) & (targets < vp_opts.stop)
            targets = torch.where(
                is_my_target,
                targets - vp_opts.start,
                ## NB
                # The backward kernel already uses
                # c.size(0) + 1 as the padding value to ensure that
                # (targets.size(0) % block_size) == 0, so for targets
                # that aren't in this VP rank's range, we can just consider
                # them as padded and all will work as expected.
                targets.new_full((), c.size(0) + 1),
            )

        ret = cce_lse_forward_kernel(
            e=e,
            c=c,
            bias=bias,
            valids=params.valids,
            softcap=params.softcap,
            return_logit_avg=return_logit_avg,
            shift=params.shift,
            targets=targets,
        )
        lse = ret.lse
        assert ret.neg_correct_logit is not None
        neg_correct_logit = ret.neg_correct_logit
        logit_avg = ret.logit_avg

        if params.vocab_parallel_options is not None:
            # Combine per-shard partial LSEs / correct logits across the group.
            lse = vp_reduce_lse(lse, pg=params.vocab_parallel_options.group)

            neg_correct_logit = vp_reduce_correct_logit(
                neg_correct_logit, pg=params.vocab_parallel_options.group, dtype=lse.dtype
            )

        # Per-token NLL = lse - correct_logit, computed in place on the
        # already-negated correct logit.
        nll = neg_correct_logit.add_(lse)

        # NOTE: the ORIGINAL (global-id) targets are saved; backward re-derives
        # the shard-local mapping itself.
        ctx.save_for_backward(e, c, bias, lse, params.targets, params.valids, logit_avg)
        ctx.params = params
        ctx.e_info = e_info
        ctx.c_info = c_info
        ctx.bias_info = bias_info

        if not params.return_lse:
            ret_lse = None
        else:
            ret_lse = handle_reduction_none(params.batch_shape, params.valids, params.shift, lse)

        reduction = params.reduction
        if reduction == "mean":
            loss = nll.mean()
        elif reduction == "sum":
            loss = nll.sum()
        elif reduction == "none":
            loss = handle_reduction_none(params.batch_shape, params.valids, params.shift, nll)
        else:
            raise ValueError(f"Unknown reduction {reduction}")

        return loss, ret_lse

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(
        ctx, grad_out: torch.Tensor, grad_lse_out: torch.Tensor | None
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, None]:
        """Backpropagates through the fused CCE backward kernel.

        Returns:
            Gradients for ``(e, c, bias)`` and None for ``params``.
        """
        e, c, bias, lse, targets, valids, logit_avg = ctx.saved_tensors

        if logit_avg is not None:
            # Order the vocabulary by average logit so the backward kernel's
            # gradient filtering can be applied efficiently.
            vocab_ordering = sort_logit_avg(logit_avg)
        else:
            vocab_ordering = None

        params = cast(CCEParams, ctx.params)
        reduction = params.reduction
        if reduction == "mean":
            # max(.., 1) guards against division by zero when no tokens are valid.
            grad_scale = 1 / max(lse.numel(), 1)
        elif reduction == "sum":
            grad_scale = 1.0
        elif reduction == "none":
            grad_scale = 1.0
            grad_out = grad_out.contiguous().view(-1)  # FIX: contiguity
        else:
            raise ValueError(f"Unknown reduction {reduction}")

        if grad_lse_out is not None:
            grad_lse_out = grad_lse_out.contiguous().view(-1)  # FIX: contiguity

        reduce_e_grad = False
        pg = None
        if (vp_opts := params.vocab_parallel_options) is not None:
            # Re-derive the shard-local targets (same remapping as in forward).
            is_my_target = (targets >= vp_opts.start) & (targets < vp_opts.stop)
            targets = torch.where(
                is_my_target,
                targets - vp_opts.start,
                ## NB
                # The backward kernel already uses
                # c.size(0) + 1 as the padding value to ensure that
                # (targets.size(0) % block_size) == 0, so for targets
                # that aren't in this VP rank's range, we can just consider
                # them as padded and all will work as expected.
                targets.new_full((), c.size(0) + 1),
            )

            reduce_e_grad = vp_opts.reduce_e_grad
            pg = vp_opts.group

        de, dc, dbias = cce_backward_kernel(
            do=grad_out,
            dlse=grad_lse_out,
            e=e,
            e_info=ctx.e_info,
            c=c,
            c_info=ctx.c_info,
            bias=bias,
            bias_info=ctx.bias_info,
            lse=lse,
            valids=valids,
            softcap=params.softcap,
            filter_eps=params.filter_eps,
            targets=targets,
            shift=params.shift,
            vocab_ordering=vocab_ordering,
            grad_scale=grad_scale,
            accum_e_fp32=params.accum_e_fp32,
            accum_c_fp32=params.accum_c_fp32,
            filter_e_grad=params.filter_e_grad,
            filter_c_grad=params.filter_c_grad,
            reduce_e_grad=reduce_e_grad,
            pg=pg,
        )

        return de, dc, dbias, None
215
+
216
+
217
def linear_cross_entropy_apply(
    e: torch.Tensor,
    c: torch.Tensor,
    bias: torch.Tensor | None,
    params: CCEParams,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Invokes ``LinearCrossEntropyFunction`` and trims shifted prefixes.

    When ``params.shift`` is non-zero, the first ``shift`` positions of the
    unreduced loss and of the returned LSE carry no meaningful values and are
    sliced away here.

    Returns:
        Tuple ``(loss, lse)`` where ``lse`` is None unless ``params.return_lse``.
    """
    result = LinearCrossEntropyFunction.apply(e, c, bias, params)
    loss, lse = cast(tuple[torch.Tensor, torch.Tensor | None], result)

    shifted = params.shift != 0

    if shifted and params.reduction == "none":
        loss = loss[..., params.shift :]

    if shifted and params.return_lse:
        assert lse is not None
        lse = lse[..., params.shift :]

    return loss, lse
236
+
237
+
238
@add_doc_start(LINEAR_CROSS_ENTROPY_DOC)
@add_doc_start(*(doc_str + "\n" for doc_str in CCE_OPTS_DOC))
def cce_linear_cross_entropy(
    e: torch.Tensor,
    c: torch.Tensor,
    targets: torch.Tensor,
    bias: torch.Tensor | None = None,
    ignore_index: int = IGNORE_INDEX,
    softcap: float | None = None,
    reduction: str = "mean",
    shift: bool | int = 0,
    return_lse: bool = False,
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    vocab_parallel_options: VocabParallelOptions | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Validates and flattens the inputs, then dispatches to the fused CCE path.

    Shape contract: ``e`` is ``(*batch, dim)``, ``targets`` is ``(*batch,)``
    and ``c``'s trailing dim matches ``e``'s embedding dim.
    """
    # `e` must match `targets` in every dim except its trailing embedding dim.
    assert e.size()[0:-1] == targets.size()
    assert e.size(-1) == c.size(1)
    if not torch.cuda.is_bf16_supported():
        raise RuntimeError(
            "Cut Cross Entropy requires an ampere GPU or newer. "
            "Consider using torch_compile_linear_cross_entropy for scenarios where one is not available."
        )

    # Remember the unflattened batch shape so reduction="none" can be restored.
    batch_shape = targets.size()

    e = e.contiguous()
    targets = targets.contiguous()

    # `shift` may be passed as a bool; normalise to an int offset.
    shift = int(shift)
    valids = _build_flat_valids(targets, ignore_index, shift)

    # Flatten batch dims: e -> (tokens, dim), targets -> (tokens,).
    e = e.flatten(0, -2)
    targets = targets.flatten()

    # The kernels need a 16-byte-aligned targets pointer (presumably for
    # vectorised loads — see the assert below). A pad+slice round-trip forces
    # a fresh, allocator-aligned copy when the current one is misaligned.
    if (targets.data_ptr() % 16) != 0:
        targets = torch.nn.functional.pad(targets, (0, 1))[:-1]

    assert (targets.data_ptr() % 16) == 0
    cce_params = CCEParams(
        targets,
        valids,
        softcap,
        reduction,
        # Resolve "auto" filter_eps against the dtype the kernels will run in.
        _handle_eps(
            filter_eps, torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else e.dtype
        ),
        shift,
        batch_shape,
        accum_e_fp32,
        accum_c_fp32,
        # Gradient filtering is meaningless without a threshold, so force the
        # flags off when filter_eps is None.
        filter_e_grad=filter_e_grad and filter_eps is not None,
        filter_c_grad=filter_c_grad and filter_eps is not None,
        vocab_parallel_options=vocab_parallel_options,
        return_lse=return_lse,
    )

    return linear_cross_entropy_apply(e, c, bias, cce_params)