d9d 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import tarfile
|
|
2
|
+
import time
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import torch.profiler as tprof
|
|
7
|
+
|
|
8
|
+
from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Profiler:
    """
    Manages distributed performance profiling using PyTorch Profiler.

    This class wraps `torch.profiler` to provide automatic trace exporting,
    compression, and file naming consistent with the distributed DeviceMesh
    topology. It configures the schedule to repeat periodically based on
    the provided step counts.
    """

    def __init__(
        self,
        save_dir: Path,
        period_steps: int,
        warmup_steps: int,
        active_steps: int,
        dist_context: DistributedContext
    ):
        """
        Constructs a Profiler object.

        Args:
            save_dir: Directory where trace files will be saved.
            period_steps: Total length of a profiling cycle (wait + warmup + active).
            warmup_steps: Number of steps to ignore before recording to allow for warming-up.
            active_steps: Number of steps to actively record traces.
            dist_context: The distributed context object.

        Raises:
            ValueError: If any step count is negative, or if ``period_steps`` is
                shorter than ``warmup_steps + active_steps`` (which would imply a
                negative wait phase).
        """

        # Fail fast on impossible schedules. Without this check the invalid
        # (negative) wait value would only surface later, inside
        # `torch.profiler.schedule` when `open()` is first called.
        if period_steps < 0 or warmup_steps < 0 or active_steps < 0:
            raise ValueError(
                "period_steps, warmup_steps and active_steps must be non-negative"
            )
        if period_steps < warmup_steps + active_steps:
            raise ValueError(
                f"period_steps ({period_steps}) must be at least "
                f"warmup_steps + active_steps ({warmup_steps} + {active_steps})"
            )

        self._save_dir = save_dir
        self._period = period_steps
        self._warmup = warmup_steps
        self._active = active_steps
        self._dist_context = dist_context

    def _get_save_file_name(self) -> str:
        """Build a per-rank trace file name so ranks sharing a directory don't collide."""
        if self._dist_context.mesh_params.is_distributed:
            mesh_regular = self._dist_context.mesh_for(REGULAR_DOMAIN)
            coord = mesh_regular.get_coordinate()
            if coord is None:
                # get_coordinate() returns None when this rank is not part of
                # the mesh — that should never happen for a profiling rank.
                raise RuntimeError("Invalid mesh")
            coord_str = "-".join(map(str, coord))
            rank = mesh_regular.get_rank()
            return f"rank-{rank}-coord-{coord_str}-trace.json"
        else:
            return "trace.json"

    def _dump_trace(self, prof: tprof.profile):
        """
        Exports the finished trace as chrome JSON, compresses it to ``.tar.gz``
        and removes the raw JSON to save disk space.

        Invoked by the profiler via ``on_trace_ready`` at the end of each
        active window.
        """
        save_dir = self._save_dir / f"step_{prof.step_num}"
        save_dir.mkdir(parents=True, exist_ok=True)
        save_file = save_dir / self._get_save_file_name()

        begin = time.monotonic()

        prof.export_chrome_trace(str(save_file))
        # Chrome traces are highly redundant JSON and compress very well;
        # keep only the archive on disk.
        with tarfile.open(save_file.with_suffix(".tar.gz"), "w:gz") as tar:
            tar.add(save_file, arcname=save_file.name)
        save_file.unlink()

        end = time.monotonic()

        self._dist_context.logger.info(
            f"Finished dumping profiler traces in {end - begin:.2f} seconds"
        )

    @contextmanager
    def open(self, start_step: int):
        """
        Opens a context manager for profiling execution.

        This sets up the `torch.profiler.profile` with a schedule derived from
        the initialization parameters. It captures both CPU and CUDA activities,
        records shapes, and tracks stack traces.

        When the schedule triggers `on_trace_ready`, the trace is automatically
        exported to the `save_dir`, compressed into a `.tar.gz` file, and the
        raw JSON is removed to save space.

        Args:
            start_step: The current global step number to initialize the
                profiler state.

        Yields:
            The configured torch profiler instance.
        """

        # The schedule cycles: wait -> warmup -> active (then repeats).
        # __init__ guarantees wait is non-negative.
        wait = self._period - (self._active + self._warmup)
        warmup = self._warmup
        active = self._active

        with tprof.profile(
            activities=[
                tprof.ProfilerActivity.CPU,
                tprof.ProfilerActivity.CUDA
            ],
            schedule=tprof.schedule(wait=wait, warmup=warmup, active=active),
            on_trace_ready=self._dump_trace,
            record_shapes=True,
            with_stack=True
        ) as profiler:
            # NOTE(review): writes the profiler's internal step counter directly
            # so a resumed job continues its schedule mid-cycle — this relies on
            # a non-public torch.profiler attribute; confirm on torch upgrades.
            profiler.step_num = start_step
            yield profiler
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from torch.distributed.checkpoint.stateful import Stateful
|
|
4
|
+
|
|
5
|
+
from d9d.core.dist_context import DistributedContext
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def state_dict_main_process(dist_context: DistributedContext, obj: Stateful) -> dict[str, Any]:
    """
    Retrieves the state dictionary of an object only on the main process.

    This is useful for checkpointing components that track global state primarily
    managed by the driver/main rank, ensuring that non-main ranks return an empty
    state to avoid duplication or synchronization issues during checkpointing.

    Args:
        dist_context: The distributed context to check for main process status.
        obj: The stateful object to serialize.

    Returns:
        A dictionary containing the object's state under the 'main_process' key on
        the main rank, and an empty dictionary on all other ranks.
    """

    # Non-main ranks contribute nothing to the checkpoint.
    if not dist_context.is_main_process:
        return {}

    return {"main_process": obj.state_dict()}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def load_state_dict_main_process(dist_context: DistributedContext, obj: Stateful, state_dict: dict[str, Any]):
    """
    Restores the state dictionary of an object only on the main process.

    Args:
        dist_context: The distributed context to check for main process status.
        obj: The stateful object to restore.
        state_dict: The state dictionary created by "state_dict_main_process" function.
    """

    # Only the main rank stored anything, so only it restores.
    if not dist_context.is_main_process:
        return

    obj.load_state_dict(state_dict["main_process"])
|
d9d/kernel/__init__.py
ADDED
|
File without changes
|
d9d/kernel/cce/cce.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import cast
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.amp
|
|
7
|
+
|
|
8
|
+
from cut_cross_entropy.cce_backward import cce_backward_kernel
|
|
9
|
+
from cut_cross_entropy.cce_lse_forward import cce_lse_forward_kernel
|
|
10
|
+
from cut_cross_entropy.constants import IGNORE_INDEX
|
|
11
|
+
from cut_cross_entropy.doc import CCE_OPTS_DOC, LINEAR_CROSS_ENTROPY_DOC, add_doc_start
|
|
12
|
+
from cut_cross_entropy.utils import (
|
|
13
|
+
TensorInfo,
|
|
14
|
+
_build_flat_valids,
|
|
15
|
+
_handle_eps,
|
|
16
|
+
handle_reduction_none,
|
|
17
|
+
)
|
|
18
|
+
from cut_cross_entropy.vocab_parallel.utils import (
|
|
19
|
+
VocabParallelOptions,
|
|
20
|
+
vp_reduce_correct_logit,
|
|
21
|
+
vp_reduce_lse,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class CCEParams:
    """
    Non-differentiable configuration bundle for `LinearCrossEntropyFunction`.

    Built by `cce_linear_cross_entropy` and passed through
    `torch.autograd.Function.apply` as a single opaque argument.
    """

    # Flattened target indices (possibly re-padded for 16-byte pointer alignment).
    targets: torch.Tensor
    # Flat indices of non-ignored positions (from `_build_flat_valids`);
    # None means every position is valid.
    valids: torch.Tensor | None
    # Optional logit softcap forwarded to the kernels; None disables it.
    softcap: float | None
    # Loss reduction mode: "mean", "sum" or "none".
    reduction: str
    # Gradient-filtering threshold (resolved by `_handle_eps`); None disables filtering.
    filter_eps: float | None
    # Target shift relative to embeddings; the public API accepts bool | int,
    # so this is commonly 0 or 1.
    shift: int
    # Original (un-flattened) shape of `targets`, used to reshape
    # "none"-reduction / LSE outputs back to batch form.
    batch_shape: torch.Size
    # Accumulate the embedding gradient in fp32 in the backward kernel.
    accum_e_fp32: bool
    # Accumulate the classifier gradient in fp32 in the backward kernel.
    accum_c_fp32: bool
    # Apply gradient filtering to the embedding gradient.
    filter_e_grad: bool
    # Apply gradient filtering to the classifier gradient.
    filter_c_grad: bool
    # Vocabulary-parallel sharding options; None when the vocab is not sharded.
    vocab_parallel_options: VocabParallelOptions | None
    # Whether forward() should also return the log-sum-exp tensor.
    return_lse: bool
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@torch.compile(fullgraph=True)
|
|
43
|
+
def sort_logit_avg(logit_avg: torch.Tensor) -> torch.Tensor:
|
|
44
|
+
return torch.argsort(logit_avg).to(torch.int32)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class LinearCrossEntropyFunction(torch.autograd.Function):
    """
    Autograd wrapper around the Cut Cross Entropy (CCE) Triton kernels.

    Fuses the classifier projection and the cross-entropy loss: forward runs a
    log-sum-exp kernel and backward a fused gradient kernel, optionally
    coordinating across a vocabulary-parallel process group.
    """

    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(
        ctx,
        e: torch.Tensor,
        c: torch.Tensor,
        bias: torch.Tensor | None,
        params: CCEParams,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Computes the loss (and optionally the LSE) from embeddings ``e``,
        classifier ``c`` and optional ``bias``.

        Returns:
            ``(loss, lse_or_None)`` — the second element is populated only when
            ``params.return_lse`` is set.
        """
        needs_grad = e.requires_grad or c.requires_grad
        if bias is not None:
            needs_grad = needs_grad or bias.requires_grad

        # logit_avg is only needed by the backward pass to order the vocab for
        # gradient filtering — skip computing it when it would go unused.
        return_logit_avg = (
            needs_grad
            and params.filter_eps is not None
            and (params.filter_c_grad or params.filter_e_grad)
        )

        # Capture dtype/requires_grad *before* any autocast conversion so the
        # backward kernel can restore the original gradient dtypes.
        e_info = TensorInfo(e.dtype, e.requires_grad)
        c_info = TensorInfo(c.dtype, c.requires_grad)

        bias_info = None
        if bias is not None:
            bias_info = TensorInfo(bias.dtype, bias.requires_grad)

        if torch.is_autocast_enabled():
            e = e.to(dtype=torch.get_autocast_gpu_dtype())
            c = c.to(dtype=torch.get_autocast_gpu_dtype())

            if bias is not None:
                bias = bias.to(dtype=torch.get_autocast_gpu_dtype())

        targets = params.targets
        if (vp_opts := params.vocab_parallel_options) is not None:
            # Remap global target ids into this shard's local vocab range.
            is_my_target = (targets >= vp_opts.start) & (targets < vp_opts.stop)
            targets = torch.where(
                is_my_target,
                targets - vp_opts.start,
                ## NB
                # The backward kernel already uses
                # c.size(0) + 1 as the padding value to ensure that
                # (targets.size(0) % block_size) == 0, so for targets
                # that aren't in this VP rank's range, we can just consider
                # them as padded and all works as expected.
                targets.new_full((), c.size(0) + 1),
            )

        ret = cce_lse_forward_kernel(
            e=e,
            c=c,
            bias=bias,
            valids=params.valids,
            softcap=params.softcap,
            return_logit_avg=return_logit_avg,
            shift=params.shift,
            targets=targets,
        )
        lse = ret.lse
        assert ret.neg_correct_logit is not None
        neg_correct_logit = ret.neg_correct_logit
        logit_avg = ret.logit_avg

        if params.vocab_parallel_options is not None:
            # Each shard only saw its slice of the vocab: combine the partial
            # LSEs and correct-logit terms across the VP group.
            lse = vp_reduce_lse(lse, pg=params.vocab_parallel_options.group)

            neg_correct_logit = vp_reduce_correct_logit(
                neg_correct_logit, pg=params.vocab_parallel_options.group, dtype=lse.dtype
            )

        # NLL = lse - correct_logit (neg_correct_logit is updated in place).
        nll = neg_correct_logit.add_(lse)

        # NOTE: saves the *global* targets (params.targets); backward redoes
        # the vocab-parallel remapping itself.
        ctx.save_for_backward(e, c, bias, lse, params.targets, params.valids, logit_avg)
        ctx.params = params
        ctx.e_info = e_info
        ctx.c_info = c_info
        ctx.bias_info = bias_info

        if not params.return_lse:
            ret_lse = None
        else:
            ret_lse = handle_reduction_none(params.batch_shape, params.valids, params.shift, lse)

        reduction = params.reduction
        if reduction == "mean":
            loss = nll.mean()
        elif reduction == "sum":
            loss = nll.sum()
        elif reduction == "none":
            loss = handle_reduction_none(params.batch_shape, params.valids, params.shift, nll)
        else:
            raise ValueError(f"Unknown reduction {reduction}")

        return loss, ret_lse

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(
        ctx, grad_out: torch.Tensor, grad_lse_out: torch.Tensor | None
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, None]:
        """
        Computes gradients w.r.t. ``e``, ``c`` and ``bias`` (None for params).
        """
        e, c, bias, lse, targets, valids, logit_avg = ctx.saved_tensors

        # Process the vocab in ascending-average-logit order so gradient
        # filtering can skip whole blocks of negligible logits.
        if logit_avg is not None:
            vocab_ordering = sort_logit_avg(logit_avg)
        else:
            vocab_ordering = None

        params = cast(CCEParams, ctx.params)
        reduction = params.reduction
        if reduction == "mean":
            # Mirrors nll.mean(): scale by 1/N (max guards the empty case).
            grad_scale = 1 / max(lse.numel(), 1)
        elif reduction == "sum":
            grad_scale = 1.0
        elif reduction == "none":
            grad_scale = 1.0
            grad_out = grad_out.contiguous().view(-1)  # FIX: contiguity
        else:
            raise ValueError(f"Unknown reduction {reduction}")

        if grad_lse_out is not None:
            grad_lse_out = grad_lse_out.contiguous().view(-1)  # FIX: contiguity

        reduce_e_grad = False
        pg = None
        if (vp_opts := params.vocab_parallel_options) is not None:
            # Same local-vocab remapping as in forward (global targets were saved).
            is_my_target = (targets >= vp_opts.start) & (targets < vp_opts.stop)
            targets = torch.where(
                is_my_target,
                targets - vp_opts.start,
                ## NB
                # The backward kernel already uses
                # c.size(0) + 1 as the padding value to ensure that
                # (targets.size(0) % block_size) == 0, so for targets
                # that aren't in this VP rank's range, we can just consider
                # them as padded and all works as expected.
                targets.new_full((), c.size(0) + 1),
            )

            reduce_e_grad = vp_opts.reduce_e_grad
            pg = vp_opts.group

        de, dc, dbias = cce_backward_kernel(
            do=grad_out,
            dlse=grad_lse_out,
            e=e,
            e_info=ctx.e_info,
            c=c,
            c_info=ctx.c_info,
            bias=bias,
            bias_info=ctx.bias_info,
            lse=lse,
            valids=valids,
            softcap=params.softcap,
            filter_eps=params.filter_eps,
            targets=targets,
            shift=params.shift,
            vocab_ordering=vocab_ordering,
            grad_scale=grad_scale,
            accum_e_fp32=params.accum_e_fp32,
            accum_c_fp32=params.accum_c_fp32,
            filter_e_grad=params.filter_e_grad,
            filter_c_grad=params.filter_c_grad,
            reduce_e_grad=reduce_e_grad,
            pg=pg,
        )

        # Fourth slot corresponds to the non-differentiable `params` argument.
        return de, dc, dbias, None
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def linear_cross_entropy_apply(
    e: torch.Tensor,
    c: torch.Tensor,
    bias: torch.Tensor | None,
    params: CCEParams,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Invokes the CCE autograd function and trims shifted leading positions.

    With a non-zero ``params.shift``, the first ``shift`` positions of the
    per-token loss (reduction "none") and of the LSE output carry no valid
    prediction, so they are sliced off here.
    """
    result = LinearCrossEntropyFunction.apply(e, c, bias, params)
    loss, lse = cast(tuple[torch.Tensor, torch.Tensor | None], result)

    shift = params.shift
    if shift != 0:
        if params.reduction == "none":
            loss = loss[..., shift:]
        if params.return_lse:
            assert lse is not None
            lse = lse[..., shift:]

    return loss, lse
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
# Public entry point; its docstring is assembled from the upstream
# cut_cross_entropy doc fragments by the decorators below.
@add_doc_start(LINEAR_CROSS_ENTROPY_DOC)
@add_doc_start(*(doc_str + "\n" for doc_str in CCE_OPTS_DOC))
def cce_linear_cross_entropy(
    e: torch.Tensor,
    c: torch.Tensor,
    targets: torch.Tensor,
    bias: torch.Tensor | None = None,
    ignore_index: int = IGNORE_INDEX,
    softcap: float | None = None,
    reduction: str = "mean",
    shift: bool | int = 0,
    return_lse: bool = False,
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    vocab_parallel_options: VocabParallelOptions | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    # Shape checks: e is (*targets.shape, dim) and the classifier's second
    # dimension is the embedding dim (presumably c is (vocab, dim) — the
    # kernels define the exact layout). NOTE: asserts are stripped under -O.
    assert e.size()[0:-1] == targets.size()
    assert e.size(-1) == c.size(1)
    if not torch.cuda.is_bf16_supported():
        raise RuntimeError(
            "Cut Cross Entropy requires an ampere GPU or newer. "
            "Consider using torch_compile_linear_cross_entropy for scenarios where one is not available."
        )

    # Remember the batch shape before flattening so "none"-reduction outputs
    # can be restored to it.
    batch_shape = targets.size()

    e = e.contiguous()
    targets = targets.contiguous()

    # The public API accepts shift as a bool; normalize to int.
    shift = int(shift)
    valids = _build_flat_valids(targets, ignore_index, shift)

    # Collapse all batch dimensions: e -> (N, dim), targets -> (N,).
    e = e.flatten(0, -2)
    targets = targets.flatten()

    # The kernels appear to require a 16-byte-aligned targets pointer;
    # pad-and-slice forces a fresh, aligned allocation when needed.
    if (targets.data_ptr() % 16) != 0:
        targets = torch.nn.functional.pad(targets, (0, 1))[:-1]

    assert (targets.data_ptr() % 16) == 0
    cce_params = CCEParams(
        targets,
        valids,
        softcap,
        reduction,
        # Resolve "auto" filter_eps against the effective compute dtype.
        _handle_eps(
            filter_eps, torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else e.dtype
        ),
        shift,
        batch_shape,
        accum_e_fp32,
        accum_c_fp32,
        # Gradient filtering is meaningless without a threshold.
        filter_e_grad=filter_e_grad and filter_eps is not None,
        filter_c_grad=filter_c_grad and filter_eps is not None,
        vocab_parallel_options=vocab_parallel_options,
        return_lse=return_lse,
    )

    return linear_cross_entropy_apply(e, c, bias, cce_params)
|