d9d-0.1.0-py3-none-any.whl
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/module/block/moe/communications/base.py
@@ -0,0 +1,58 @@
import abc

import torch


class ExpertCommunicationHandler(abc.ABC):
    """Abstract base class for Mixture-of-Experts communication strategies."""

    @abc.abstractmethod
    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepares and routes local hidden states to their target experts (possibly on other workers).

        This process involves:

        1. All-to-All Communication: Transfers hidden states to workers containing the assigned experts.
           States assigned to multiple experts are replicated.
        2. Permutation: Sorts tokens by expert ID to prepare for Grouped GEMM.

        Args:
            hidden_states: Input tokens. Shape: `(num_tokens, hidden_size)`.
            topk_ids: Indices of the top-k experts selected for each token. Shape: `(num_tokens, k)`.
            topk_weights: Routing weights associated with the selected experts. Shape: `(num_tokens, k)`.

        Returns:
            A tuple containing:

            - Permuted hidden states received by this rank. Shape: `(num_received_tokens, hidden_size)`.
            - Permuted weights matching the hidden states order. Shape: `(num_received_tokens)`.
            - Expert count tensor indicating how many tokens each local expert received.
              Shape: `(num_local_experts)`.
        """
        ...

    @abc.abstractmethod
    def combine(
        self,
        hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """
        Restores hidden states to their original order and location.

        Undoes the permutation and performs the reverse All-to-All communication
        to return processed results to the workers that originated the requests.

        Args:
            hidden_states: The processed hidden states. Shape: `(num_received_tokens, hidden_size)`.

        Returns:
            The combined hidden states with the original shape and order. Shape: `(num_tokens, hidden_size)`.
        """
        ...
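
For orientation, a minimal sketch (not part of the package; `handler` and `experts` are placeholder names) of how an implementation of this contract is consumed by an MoE forward pass:

def moe_forward(handler, experts, hidden_states, topk_ids, topk_weights):
    # 1. All-to-all + permutation: tokens come back grouped by local expert
    permuted, permuted_weights, tokens_per_expert = handler.dispatch(
        hidden_states, topk_ids, topk_weights
    )
    # 2. Expert computation runs on the contiguously grouped tokens
    expert_out = experts(permuted, permuted_weights, tokens_per_expert)
    # 3. Reverse permutation + all-to-all restores original order and placement
    return handler.combine(expert_out)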
d9d/module/block/moe/communications/deepep.py
@@ -0,0 +1,300 @@
from typing import Any

import torch
from deep_ep import Buffer
from torch.autograd.function import FunctionCtx

from d9d.kernel.moe.indices_to_multihot import fused_indices_to_multihot
from d9d.kernel.moe.permute_with_probs import moe_permute_with_probs, moe_unpermute_mask
from d9d.module.block.moe.communications import ExpertCommunicationHandler

# see https://github.com/deepseek-ai/DeepEP/blob/main/README.md for examples
# TODO: implement computation/communication overlap for PP case

_buffer: Buffer | None = None


def get_hidden_state_bytes(x: torch.Tensor) -> int:
    """
    Calculates the byte size of a hidden state tensor row.

    Args:
        x: Input tensor. Shape: `(?, hidden_size)`.
    """

    return x.size(1) * max(x.element_size(), 2)


def init_deepep_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int):
    """
    Initializes or expands the global DeepEP communication buffer.

    Checks if the existing buffer is sufficient for the required hidden dimension
    and process group size. If not, it allocates a new buffer.

    Args:
        group: The process group intended for communication.
        hidden_bytes: Size of a single hidden state vector in bytes.
    """

    global _buffer  # noqa: PLW0603
    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (
        Buffer.get_dispatch_config(group.size()),
        Buffer.get_combine_config(group.size()),
    ):
        num_nvl_bytes = max(
            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
        )
        num_rdma_bytes = max(
            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
        )

    # Allocate a new buffer if none exists yet or the existing one is too small
    if (
        _buffer is None
        or _buffer.group != group
        or _buffer.num_nvl_bytes < num_nvl_bytes
        or _buffer.num_rdma_bytes < num_rdma_bytes
    ):
        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)


class DeepEpDispatch(torch.autograd.Function):
    """Autograd function for the DeepEP Dispatch operation."""

    @staticmethod
    def forward(
        ctx: FunctionCtx,
        x: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        tuple
    ]:
        previous_event = Buffer.capture()
        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            previous_event
        ) = _buffer.get_dispatch_layout(
            topk_idx, num_experts,
            previous_event=previous_event,
            async_finish=True,
            allocate_on_comm_stream=True
        )

        (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            handle,
            event
        ) = _buffer.dispatch(
            x,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
            previous_event=previous_event,
            async_finish=True,
            allocate_on_comm_stream=True
        )

        event.current_stream_wait()

        num_recv_tokens_per_expert_list = torch.tensor(num_recv_tokens_per_expert_list)

        ctx.handle = handle

        return (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            handle
        )

    @staticmethod
    def backward(
        ctx: FunctionCtx,
        grad_recv_x: torch.Tensor,
        grad_recv_topk_idx: torch.Tensor,
        grad_recv_topk_weights: torch.Tensor,
        grad_num_recv_tokens_per_expert_list: list,
        grad_handle: Any
    ) -> tuple[
        torch.Tensor,
        None,
        torch.Tensor,
        None
    ]:
        handle = ctx.handle

        prev_event = Buffer.capture()

        (
            combined_grad_x,
            combined_grad_recv_topk_weights,
            event
        ) = _buffer.combine(
            grad_recv_x.contiguous(),
            handle,
            topk_weights=grad_recv_topk_weights,
            async_finish=True,
            previous_event=prev_event,
            allocate_on_comm_stream=True
        )

        event.current_stream_wait()

        return combined_grad_x, None, combined_grad_recv_topk_weights, None


class DeepEpCombine(torch.autograd.Function):
    """Autograd function for the DeepEP Combine operation."""

    @staticmethod
    def forward(
        ctx: FunctionCtx,
        x: torch.Tensor,
        handle: Any
    ) -> torch.Tensor:
        previous_event = Buffer.capture()

        combined_x, _, event = _buffer.combine(
            x,
            handle,
            async_finish=True,
            previous_event=previous_event,
            allocate_on_comm_stream=True
        )

        event.current_stream_wait()

        ctx.handle = handle

        return combined_x

    @staticmethod
    def backward(ctx: FunctionCtx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]:
        handle = ctx.handle

        previous_event = Buffer.capture()

        grad_x, _, _, _, _, event = _buffer.dispatch(
            grad_output.contiguous(),
            handle=handle,
            async_finish=True,
            previous_event=previous_event,
            allocate_on_comm_stream=True
        )

        event.current_stream_wait()

        return grad_x, None


class DeepEpCommunicationHandler(ExpertCommunicationHandler):
    """Handles MoE communication using the high-performance DeepEP library."""

    def __init__(self, num_experts: int):
        """Constructs the DeepEpCommunicationHandler."""

        self._num_experts = num_experts
        self._num_experts_per_shard = None  # late-initialization

        # == fields saved for post-dispatch ==

        self._handle = None
        self._hidden_shape_before_permute = None
        self._unpermute_mapping = None

    def setup(self, group: torch.distributed.ProcessGroup, hidden_size: int, hidden_dtype: torch.dtype):
        """
        Initializes the backend buffer and calculates expert sharding.

        Args:
            group: The process group containing all experts.
            hidden_size: Dimensionality of the hidden states.
            hidden_dtype: Data type of the hidden states.
        """

        init_deepep_buffer(group, hidden_size * hidden_dtype.itemsize)

        if self._num_experts % group.size() != 0:
            raise ValueError("num_experts must be divisible by distributed group size")

        self._num_experts_per_shard = self._num_experts // group.size()

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        (
            hidden_states,
            topk_ids,
            topk_weights,
            tokens_per_expert,
            handle
        ) = DeepEpDispatch.apply(
            hidden_states,
            topk_ids,
            topk_weights,
            self._num_experts
        )

        routing_map, routing_probs = fused_indices_to_multihot(
            topk_ids, topk_weights, self._num_experts_per_shard
        )

        self._hidden_shape_before_permute = hidden_states.shape

        hidden_states, routing_probs, reverse_permute_map = moe_permute_with_probs(
            hidden_states,
            routing_probs,
            routing_map,
            num_out_tokens=tokens_per_expert.sum().item()
        )

        self._handle = handle
        self._unpermute_mapping = reverse_permute_map

        return hidden_states, routing_probs, tokens_per_expert

    def combine(
        self,
        hidden_states: torch.Tensor
    ) -> torch.Tensor:
        if self._handle is None:
            raise ValueError("MoE communication calls are out of order: dispatch() must run before combine()")

        hidden_states = moe_unpermute_mask(
            hidden_states,
            self._unpermute_mapping,
            restore_shape=self._hidden_shape_before_permute,
        )

        hidden_states = DeepEpCombine.apply(
            hidden_states,
            self._handle
        )

        self._handle = None
        self._unpermute_mapping = None
        self._hidden_shape_before_permute = None

        return hidden_states
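
A setup sketch, assuming `torch.distributed` is already initialized and that every rank participates in expert parallelism (the group choice and sizes below are assumptions, not package defaults):

import torch
import torch.distributed as dist

# assumption: 64 experts divide evenly across the world size
handler = DeepEpCommunicationHandler(num_experts=64)
handler.setup(dist.group.WORLD, hidden_size=4096, hidden_dtype=torch.bfloat16)
# each forward pass then pairs handler.dispatch(...) with handler.combine(...)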
d9d/module/block/moe/communications/naive.py
@@ -0,0 +1,68 @@
from typing import cast

import torch
from torch import Size

from d9d.kernel.moe import (
    fused_indices_to_multihot,
    moe_permute_with_probs,
    moe_unpermute_mask,
)
from d9d.module.block.moe.communications import ExpertCommunicationHandler


class NoCommunicationHandler(ExpertCommunicationHandler):
    """
    Handles MoE routing within a single device or when no cross-device routing is needed.

    This handler performs no network operations; it only permutes tokens locally,
    grouping them by expert (useful for single-device runs and debugging).
    """

    def __init__(self, num_experts: int):
        """Constructs the NoCommunicationHandler."""
        self._num_experts = num_experts

        self._hidden_shape_before_permute: Size | None = None
        self._unpermute_mapping: torch.Tensor | None = None

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            tokens_per_expert = torch.bincount(topk_ids.flatten(), minlength=self._num_experts).cpu()

        routing_map, routing_probs = fused_indices_to_multihot(
            topk_ids, topk_weights, self._num_experts
        )

        self._hidden_shape_before_permute = hidden_states.shape

        hidden_states, routing_probs, reverse_permute_map = moe_permute_with_probs(
            hidden_states,
            routing_probs,
            routing_map,
            num_out_tokens=cast(int, tokens_per_expert.sum().item())
        )

        self._unpermute_mapping = reverse_permute_map

        return hidden_states, routing_probs, tokens_per_expert

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self._unpermute_mapping is None:
            raise ValueError("Cannot run combine before running dispatch!")

        hidden_states = moe_unpermute_mask(
            hidden_states,
            self._unpermute_mapping,
            restore_shape=self._hidden_shape_before_permute,
        )

        self._unpermute_mapping = None
        self._hidden_shape_before_permute = None

        return hidden_states
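
A single-process round trip with illustrative sizes (a sketch; the permute kernels assume a GPU build):

import torch

num_experts, top_k = 4, 2
handler = NoCommunicationHandler(num_experts)

x = torch.randn(8, 16, device="cuda")                 # (num_tokens, hidden_size)
topk_ids = torch.randint(0, num_experts, (8, top_k), device="cuda")
topk_weights = torch.rand(8, top_k, device="cuda")

permuted, probs, counts = handler.dispatch(x, topk_ids, topk_weights)
# `permuted` holds 8 * top_k rows grouped by expert; `counts` sums to 16
out = handler.combine(permuted)                       # back to original order and shape
assert out.shape == x.shape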
d9d/module/block/moe/grouped_experts.py
@@ -0,0 +1,81 @@
import torch
from torch import nn

from d9d.kernel.swiglu import silu_mul
from d9d.module.base import ModuleLateInit

from .grouped_linear import GroupedLinear


class GroupedSwiGLU(nn.Module, ModuleLateInit):
    """
    Executes a collection of SwiGLU experts efficiently using Grouped GEMM.

    This module implements the architectural pattern: `down_proj(SiLU(gate_proj(x)) * up_proj(x))`.
    It applies this operation across multiple discrete experts in parallel without padding or masking.
    """

    def __init__(
        self,
        hidden_dim: int,
        intermediate_dim: int,
        num_experts: int
    ):
        """
        Constructs the GroupedSwiGLU module.

        Args:
            hidden_dim: Dimensionality of the input and output hidden states.
            intermediate_dim: Dimensionality of the intermediate projection.
            num_experts: Total number of experts managed by this local instance.
        """

        super().__init__()
        self._num_experts = num_experts

        self.gate_proj = GroupedLinear(num_experts, hidden_dim, intermediate_dim)
        self.up_proj = GroupedLinear(num_experts, hidden_dim, intermediate_dim)
        self.down_proj = GroupedLinear(num_experts, intermediate_dim, hidden_dim)

    def forward(
        self,
        permuted_x: torch.Tensor,
        permuted_probs: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes expert outputs for sorted input tokens.

        Args:
            permuted_x: Input tokens sorted by their assigned expert.
                Shape: `(total_tokens, hidden_dim)`.
            permuted_probs: Routing weights/probabilities corresponding to the sorted tokens.
                Shape: `(total_tokens)`.
            tokens_per_expert: Number of tokens assigned to each consecutive expert. It is a CPU tensor.
                Shape: `(num_experts)`.

        Returns:
            The computed and weighted output tokens (still permuted).
            Shape: `(total_tokens, hidden_dim)`.
        """

        if permuted_x.numel() == 0:  # handle the case where no tokens were routed to this instance
            return permuted_x

        probs = permuted_probs[:, None].to(permuted_x.dtype)
        values = self.down_proj(
            silu_mul(
                self.gate_proj(permuted_x, tokens_per_expert),
                self.up_proj(permuted_x, tokens_per_expert)
            ),
            tokens_per_expert
        )

        return probs * values

    def reset_parameters(self):
        """Resets parameters for all internal linear projections."""

        self.gate_proj.reset_parameters()
        self.up_proj.reset_parameters()
        self.down_proj.reset_parameters()
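
Shape flow through the grouped experts, with illustrative numbers (a sketch; `silu_mul` and the grouped GEMM require a CUDA device):

import torch

experts = GroupedSwiGLU(hidden_dim=16, intermediate_dim=32, num_experts=3).cuda()
tokens_per_expert = torch.tensor([5, 0, 3])       # CPU tensor; a group may be empty
permuted_x = torch.randn(8, 16, device="cuda")    # tokens pre-sorted by expert
permuted_probs = torch.rand(8, device="cuda")

out = experts(permuted_x, permuted_probs, tokens_per_expert)
assert out.shape == (8, 16)                       # weighted outputs, still permuted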
d9d/module/block/moe/grouped_linear.py
@@ -0,0 +1,78 @@
import math

import torch
from torch import nn
from torch.distributed.tensor import DTensor

from d9d.core.autograd import GradDirection
from d9d.kernel.gmm import gmm
from d9d.module.base import ModuleLateInit


class GroupedLinear(nn.Module, ModuleLateInit):
    """
    Applies a linear transformation using Grouped GEMM (General Matrix Multiplication).

    This module allows efficient execution of multiple linear layers (experts) in parallel, where each expert
    processes a variable number of tokens.
    It is the computational core of the Mixture-of-Experts layer.
    """

    def __init__(
        self,
        n_groups: int,
        in_features: int,
        out_features: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None
    ):
        """
        Constructs the GroupedLinear layer.

        Args:
            n_groups: Number of groups (experts).
            in_features: Input hidden size.
            out_features: Output hidden size.
            device: Target device.
            dtype: Target data type.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_groups, in_features, out_features,
                                               device=device, dtype=dtype))

        self.n_groups = n_groups
        self.in_features = in_features
        self.out_features = out_features

        self.reset_parameters()

    def forward(self, x: torch.Tensor, x_groups: torch.Tensor) -> torch.Tensor:
        """
        Performs the grouped matrix multiplication.

        Args:
            x: Flattened input tensor containing tokens for all groups.
                Shape: `(total_tokens, in_features)`.
            x_groups: CPU tensor indicating the number of tokens assigned to each group.
                Must sum to `total_tokens`. Shape: `(n_groups,)`.

        Returns:
            The output tensor. Shape: `(total_tokens, out_features)`.
        """

        weight: torch.Tensor = self.weight

        if isinstance(weight, DTensor):
            weight = weight.to_local()

        return gmm(
            x,
            weight,
            x_groups,
            a_grad_direction=GradDirection.inputs,
            b_grad_direction=GradDirection.weight
        )

    def reset_parameters(self):
        """Initializes weights using a uniform distribution based on input features."""
        nn.init.uniform_(self.weight, -1 / math.sqrt(self.in_features), 1 / math.sqrt(self.in_features))
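
Per group `g`, the kernel computes `y[g] = x[g] @ weight[g]` over variable-length token slices. A naive PyTorch equivalent of these semantics (for clarity only, not the actual kernel):

import torch

def grouped_linear_reference(x: torch.Tensor, weight: torch.Tensor, x_groups: torch.Tensor) -> torch.Tensor:
    # x: (total_tokens, in_features); weight: (n_groups, in_features, out_features)
    # x_groups: CPU tensor of per-group token counts summing to total_tokens
    outputs, start = [], 0
    for group_idx, count in enumerate(x_groups.tolist()):
        outputs.append(x[start:start + count] @ weight[group_idx])
        start += count
    return torch.cat(outputs, dim=0)  # (total_tokens, out_features)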
d9d/module/block/moe/layer.py
@@ -0,0 +1,122 @@
import torch
from torch import nn
from torch.distributed import ProcessGroup

from d9d.module.base import ModuleLateInit

from .communications import (
    DeepEpCommunicationHandler,
    ExpertCommunicationHandler,
    NoCommunicationHandler,
)
from .grouped_experts import GroupedSwiGLU
from .router import TopKRouter

# TODO: implement expert bias
# TODO: shared experts


class MoELayer(nn.Module, ModuleLateInit):
    """
    A complete Mixture-of-Experts (MoE) block comprising routing, communication, and computation.

    This layer integrates:

    1. **Router**: Selects experts for each token.
    2. **Communicator**: Handles token dispatch to local or remote experts (EP).
    3. **Experts**: Perform the parallelized computation (Grouped SwiGLU).
    """

    def __init__(
        self,
        hidden_dim: int,
        intermediate_dim_grouped: int,
        num_grouped_experts: int,
        top_k: int,
        router_renormalize_probabilities: bool
    ):
        """
        Constructs the MoELayer.

        Args:
            hidden_dim: Hidden size.
            intermediate_dim_grouped: Intermediate dimension for the expert FFNs.
            num_grouped_experts: Total number of experts.
            top_k: Number of experts to route each token to.
            router_renormalize_probabilities: Whether the router renormalizes the selected top-k probabilities.
        """

        super().__init__()
        self.router = TopKRouter(
            dim=hidden_dim, num_experts=num_grouped_experts, top_k=top_k,
            renormalize_probabilities=router_renormalize_probabilities
        )
        self.grouped_experts = GroupedSwiGLU(
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim_grouped,
            num_experts=num_grouped_experts
        )
        self._communicator: ExpertCommunicationHandler = NoCommunicationHandler(num_grouped_experts)

        self._num_grouped_experts = num_grouped_experts
        self._hidden_dim = hidden_dim

        self.tokens_per_expert = nn.Buffer(torch.empty((num_grouped_experts,), dtype=torch.int64), persistent=False)

    def enable_distributed_communicator(self, group: ProcessGroup):
        """
        Switches from local no-op communication to distributed DeepEP communication.

        This should be called during model initialization if the model is running in a
        distributed Expert Parallel environment.

        Args:
            group: The PyTorch process group spanning the expert parallel ranks.
        """

        communicator = DeepEpCommunicationHandler(num_experts=self._num_grouped_experts)
        communicator.setup(group, self._hidden_dim, self.router.gate.weight.dtype)
        self._communicator = communicator

    @torch.no_grad()
    def _update_tokens_per_expert(self, expert_indices: torch.Tensor):
        self.tokens_per_expert.add_(expert_indices.view(-1).bincount(minlength=self._num_grouped_experts))

    @torch.no_grad()
    def reset_stats(self):
        """Resets the expert load balancing counters."""
        self.tokens_per_expert.zero_()

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """
        Routes tokens to experts, computes, and combines results.

        Args:
            hidden_states: Input tensor. Shape: `(batch_size, seq_len, hidden_dim)`.

        Returns:
            Output tensor combined from experts. Shape: `(batch_size, seq_len, hidden_dim)`.
        """

        old_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
        expert_indices, expert_scores = self.router(hidden_states)
        self._update_tokens_per_expert(expert_indices)
        hidden_states, expert_scores, expert_count = self._communicator.dispatch(
            hidden_states, expert_indices, expert_scores
        )
        hidden_states = self.grouped_experts(hidden_states, expert_scores, expert_count)
        hidden_states = self._communicator.combine(hidden_states)
        hidden_states = hidden_states.reshape(*old_shape)

        return hidden_states

    def reset_parameters(self):
        """Resets module parameters."""
        self.router.reset_parameters()
        self.grouped_experts.reset_parameters()

        nn.init.zeros_(self.tokens_per_expert)
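
Putting the pieces together, a hypothetical end-to-end call (sizes are illustrative; the underlying kernels require CUDA, and expert parallelism additionally needs `enable_distributed_communicator`):

import torch

layer = MoELayer(
    hidden_dim=1024,
    intermediate_dim_grouped=2048,
    num_grouped_experts=8,
    top_k=2,
    router_renormalize_probabilities=True,
).cuda()
# in a distributed Expert Parallel run, additionally:
# layer.enable_distributed_communicator(ep_group)

x = torch.randn(4, 128, 1024, device="cuda")  # (batch, seq, hidden)
y = layer(x)                                  # (batch, seq, hidden)
load = layer.tokens_per_expert                # per-expert token counts since reset_stats()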