d9d 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,1035 @@
|
|
|
1
|
+
# https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/permutation.py
|
|
2
|
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import triton
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
import triton.language as tl
|
|
9
|
+
from triton.language.standard import _log2
|
|
10
|
+
|
|
11
|
+
from d9d.kernel.general import get_int_dtype
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@triton.jit
def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
    """
    Apply one compare-and-swap layer of a bitonic sorting network to ``x``,
    moving ``indices`` in lockstep.

    ``x`` is viewed as ``x.numel >> n_dims`` independent runs of ``2**n_dims``
    elements. Within each run, elements are paired with a partner
    ``2 ** (n_dims - i - 1)`` positions away. A pair is swapped wherever
    ``(left > right) ^ flip`` is non-zero.

    Parameters
    ----------
    x:
        Block of values being sorted.
    indices:
        Payload block of the same shape as ``x``; swapped exactly where ``x`` is.
    flip:
        Per-element direction flag (0 or 1), same shape as ``x``.
    i: tl.constexpr
        Which pairing layer to apply; larger ``i`` means closer partners.
    n_dims: tl.constexpr
        log2 of the run length being sorted.

    Returns
    -------
    tuple
        ``(x, indices)`` after this layer.
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    # Middle axis of size 2 separates the left and right element of every pair.
    shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)]
    y = tl.reshape(x, shape)
    z = tl.reshape(indices, shape)

    # mask == 0 selects the left element of a pair, mask == 1 the right one.
    mask = tl.arange(0, 2)[None, :, None]

    # Extract the left / right value of each pair and broadcast it back so that
    # both positions of the pair see both values.
    l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to(
        x.dtype
    )
    r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to(
        x.dtype
    )

    l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
    r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)

    # Integer dtype with the same bit width as x, so values can be bitcast and XOR-ed.
    idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)

    il_value = l_value.to(idtype, bitcast=True)
    ir_value = r_value.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # Branch-free swap: a lane holding l becomes l ^ (l ^ r) == r (and vice versa)
    # wherever the swap condition holds; otherwise it is XOR-ed with 0.
    flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix))
    ret = ix ^ flag1
    # NOTE(review): zeros_like(ix) has x's integer width, so indices are presumably
    # expected to share that width (int32 in this file) — confirm for other dtypes.
    flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix))
    ind = indices ^ flag2

    return ret.to(x.dtype, bitcast=True), ind
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@triton.jit
def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
    """
    Run one merge stage of a bitonic sort over runs of ``2**n_dims`` elements.

    The stage applies ``stage`` compare-and-swap layers. Partner distances go
    from ``2 ** (stage - 1)`` down to 1.

    Parameters
    ----------
    x:
        Block of values being sorted.
    indices:
        Payload block permuted alongside ``x``.
    stage: tl.constexpr
        Merge stage; must satisfy ``stage <= n_dims``.
    order: tl.constexpr
        Sort direction. Fixed for ``0`` / ``1``. For ``2``, the direction
        alternates between consecutive sub-runs of ``2**stage`` elements.
    n_dims: tl.constexpr
        log2 of the run length being sorted.

    Returns
    -------
    tuple
        ``(x, indices)`` after the merge stage.
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    tl.static_assert(stage <= n_dims)
    """
    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    if order == 2:
        # flip is 0 for one sub-run of 2**stage elements and 1 for the next, alternating.
        shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage]
        flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        # Same direction for every element.
        flip = tl.full(x.shape, value=order, dtype=tl.int32)
    for i in tl.static_range(stage):
        # Layer index offset by (n_dims - stage) so distances shrink from 2**(stage-1) to 1.
        x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims)
    return x, indices
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@triton.jit
def _argsort(x, indices, n_dims: tl.constexpr):
    """
    Bitonic-sort ``x`` and permute ``indices`` the same way.

    Every stage but the last uses alternating order to build bitonic sequences.
    The final stage uses order ``1``, which ``_bitonic_merge`` documents as
    descending. The result is therefore sorted in descending order.

    Parameters
    ----------
    x:
        Values to sort; runs of ``2**n_dims`` elements are sorted independently.
    indices:
        Payload permuted alongside ``x`` (e.g. original positions).
    n_dims: tl.constexpr
        log2 of the run length.

    Returns
    -------
    tuple
        ``(sorted_x, permuted_indices)``.
    """
    for i in tl.static_range(1, n_dims + 1):
        x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims)
    return x, indices
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@triton.jit
def _row_id_map_pass_1_kernel(
    # pointers
    routing_map_ptr,
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # strides
    stride_routing_map_token,
    stride_routing_map_expert,
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    """
    Pass 1 of the row-id-map construction: per-(expert, token-block) cumulative sums.

    Program ``(pid_m, pid_n)`` handles expert ``pid_m`` and tokens
    ``[pid_n * BLOCK_SIZE, (pid_n + 1) * BLOCK_SIZE)``. It performs two writes:

    - ``row_id_map[token, expert]``: the inclusive cumulative count of routed
      tokens within the block, or 0 where the token is not routed to this expert.
    - ``workspace[pid_m * n_blocks + pid_n]``: the number of routed tokens in the
      block, where ``n_blocks = cdiv(num_tokens, BLOCK_SIZE)``.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # assumes routing_map entries are 0/1 (bool or int) — values are used as counts.
    expert_token_mask = tl.load(
        routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
        mask=(offset < num_tokens),
        other=0,
    ).to(tl.int32)
    # Inclusive cumsum, zeroed for tokens not routed to this expert.
    row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
    tl.store(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        row_id_within_token_block,
        mask=offset < num_tokens,
    )
    # Per-block count, laid out expert-major so pass 2 can prefix-sum it linearly.
    n_tokens_per_block = tl.sum(expert_token_mask)
    tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@triton.jit
def _row_id_map_pass_2_kernel(
    # pointers
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    WORKSPACE_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Pass 2 of the row-id-map construction: turn block-local cumsums into global row ids.

    Program ``(pid_m, pid_n)`` handles expert ``pid_m`` and token block ``pid_n``.
    Its chunk index in the expert-major workspace is
    ``chunk_idx = pid_m * n_blocks + pid_n``. The sum of all workspace entries
    before ``chunk_idx`` is the number of routed (token, expert) pairs that
    precede this chunk.

    Each row-id-map entry then becomes one of two values:

    - ``block_cumsum + prefix - 1`` for routed entries, i.e. the 0-based
      destination row.
    - ``-1`` for entries that are not routed.

    ``WORKSPACE_LOAD_WIDTH`` must be a power of two at least as large as the
    workspace length.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_id_within_token_block = tl.load(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        mask=(offset < num_tokens),
        other=0,
    )

    workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
    # Masked-out lanes (chunks at or after this one, and padding past the workspace end)
    # must contribute exactly 0 to the prefix sum; tl.load leaves them undefined otherwise.
    n_tokens_per_chunk = tl.load(
        workspace_ptr + workspace_off, mask=workspace_off < chunk_idx, other=0
    )
    row_id = tl.where(
        row_id_within_token_block == 0,
        -1,
        row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
    )
    tl.store(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        row_id,
        mask=(offset < num_tokens),
    )
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@triton.jit
def _row_id_map_pass_3_kernel(
    # pointers
    row_id_map_ptr,
    # sizes
    num_experts: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    LOAD_SIZE: tl.constexpr,
):
    """
    Pass 3 of the row-id-map construction: compact each token's row ids into a dense layout.

    One program handles one token (row). It reads the ``num_experts`` row ids
    produced by pass 2 (``-1`` meaning "not routed") and sorts them in
    descending order, carrying along the expert positions. It then rewrites
    the row in place:

    - columns ``[0, n_routed)``: destination row ids;
    - columns ``[num_experts, num_experts + n_routed)``: the matching expert indices;
    - column ``2 * num_experts``: ``n_routed``.

    ``LOAD_SIZE`` must be a power of two ``>= num_experts`` (it sets the sort width).
    """
    pid = tl.program_id(0)
    n_dims: tl.constexpr = _log2(LOAD_SIZE)
    off = tl.arange(0, LOAD_SIZE)
    # Lanes past num_experts are padded with -1 so they sort to the end (descending sort).
    row_id_map = tl.load(
        row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off,
        mask=off < num_experts,
        other=-1,
    )
    n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0))
    # Payload for the sort: the original column (= expert index) of each entry.
    indices = off
    sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
        sorted_map,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr
        + pid * stride_row_id_map_token
        + (num_experts + off) * stride_row_id_map_expert,
        indices,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
        n_routed,
    )
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def make_row_id_map(
    routing_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
):
    """
    Prepare the row_id_map for the permutation.

    Parameters
    ----------
    routing_map: torch.Tensor
        Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
        which experts are routed to which tokens. The values in it: 1 means the token is routed to
        this expert and 0 means not.
    num_tokens: int
        Number of tokens in the input tensor.
    num_experts: int
        Number of experts in the input tensor.

    Returns
    -------
    row_id_map: torch.Tensor
        The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`,
        dtype int32, allocated on the same device as ``routing_map``.
        For each token, the last item is the number of experts that are routed (n_routed).
        The first n_routed items are the destination row indices in the permuted tokens.
        The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding
        to the first n_routed row indices above.
        Destination rows are grouped by expert (expert-major), tokens in order within an expert.
    """
    # Allocate next to the input rather than on the *current* CUDA device, which may differ.
    device = routing_map.device
    row_id_map = torch.empty((num_tokens, num_experts * 2 + 1), dtype=torch.int32, device=device)
    if num_tokens == 0:
        # Nothing to map. Launching anyway would ask pass 2 for a workspace load width of
        # triton.next_power_of_2(0) == 0, which is not a valid tl.arange width.
        return row_id_map

    block_size = 1024
    grid = (num_experts, triton.cdiv(num_tokens, block_size))
    workspace_tensor = torch.empty(grid, dtype=torch.int32, device=device)

    # supposing num_tokens == 5, num_experts == 3, block_size == 3
    # and we have a routing_map like this:
    # [[1, 1, 0],
    #  [1, 0, 1],
    #  [0, 0, 1],
    #  [1, 1, 0],
    #  [0, 0, 0]]

    # pass 1: block cumsum
    # for each expert, compute the cumsum of every block_size tokens
    # the row_id_map will be like this after pass 1 (r means useless values):
    # [[1, 1, 0, r, r, r, r],
    #  [2, 0, 1, r, r, r, r],
    #  [0, 0, 2, r, r, r, r],
    #  [1, 1, 0, r, r, r, r],
    #  [0, 0, 0, r, r, r, r]]
    _row_id_map_pass_1_kernel[grid](
        routing_map,
        row_id_map,
        workspace_tensor,
        num_tokens,
        routing_map.stride(0),
        routing_map.stride(1),
        row_id_map.stride(0),
        row_id_map.stride(1),
        block_size,
    )

    # pass 2: cumsum all and process the mask
    # process the block cumsum into the global cumsum and then into the dst row indices
    # the row_id_map will be like this after pass 2 (r means useless value):
    # [[ 0,  3, -1, r, r, r, r],
    #  [ 1, -1,  5, r, r, r, r],
    #  [-1, -1,  6, r, r, r, r],
    #  [ 2,  4, -1, r, r, r, r],
    #  [-1, -1, -1, r, r, r, r]]
    _row_id_map_pass_2_kernel[grid](
        row_id_map,
        workspace_tensor,
        num_tokens,
        row_id_map.stride(0),
        row_id_map.stride(1),
        # Each pass-2 program loads the whole workspace; the width must be a power of two.
        triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)),
        block_size,
    )

    # pass 3: make the row_id_map from the sparse structure to the dense structure
    # the row_id_map will be like this after pass 3 (r means useless value):
    # [[3, 0, r, 1, 0, r, 2],
    #  [5, 1, r, 2, 0, r, 2],
    #  [6, r, r, 2, r, r, 1],
    #  [4, 2, r, 1, 0, r, 2],
    #  [r, r, r, r, r, r, 0]]
    grid = (num_tokens,)
    _row_id_map_pass_3_kernel[grid](
        row_id_map,
        num_experts,
        row_id_map.stride(0),
        row_id_map.stride(1),
        triton.next_power_of_2(num_experts),
    )
    return row_id_map
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
@triton.jit
def _permute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    scale_ptr,
    permuted_probs_ptr,
    permuted_scale_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    scale_hidden_dim,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_probs_expert,
    stride_scale_token,
    stride_scale_hidden,
    stride_permuted_probs_token,
    stride_permuted_scale_token,
    stride_permuted_scale_hidden,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    PERMUTE_SCALE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Scatter each input token to every destination row it is routed to.

    Program ``(pid_t, pid_h)`` handles token ``pid_t`` and hidden columns
    ``[pid_h * BLOCK_SIZE, (pid_h + 1) * BLOCK_SIZE)``. The dense row-id map
    (built by ``make_row_id_map``) supplies ``n_routed`` and the destination
    rows. The hidden chunk is copied into each of those rows.

    Optional extras:

    - ``PERMUTE_SCALE``: the matching chunk of the scale tensor (columns
      ``< scale_hidden_dim``) is copied alongside.
    - ``PERMUTE_PROBS``: ``probs[token, expert]`` is written to
      ``permuted_probs[dst_row]``, by the ``pid_h == 0`` program only. A
      destination whose probability is exactly 0.0 is treated as padding and
      receives zeros instead of the input.
    """
    pid_t = tl.program_id(0)
    pid_h = tl.program_id(1)
    cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = cur_off < hidden_size
    input_off = pid_t * stride_input_token + cur_off * stride_input_hidden
    inp = tl.load(input_ptr + input_off, mask=mask)
    if PERMUTE_SCALE:
        mask_scale = cur_off < scale_hidden_dim
        scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
        scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
    # Number of experts this token is routed to (last column of the row-id map).
    n_routed = tl.load(
        row_id_map_ptr
        + pid_t * stride_row_id_map_token
        + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        # Column idx holds the idx-th destination row of this token.
        dst_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
        )
        output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
        if PERMUTE_SCALE:
            permuted_scale_off = (
                dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
            )
            tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
        if PERMUTE_PROBS:
            # Column num_experts + idx holds the expert index matching dst_row.
            expert_idx = tl.load(
                row_id_map_ptr
                + pid_t * stride_row_id_map_token
                + (num_experts + idx) * stride_row_id_map_expert
            )
            prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
            prob = tl.load(probs_ptr + prob_off)
            # Only one program per token writes the (scalar) permuted probability.
            if pid_h == 0:
                permuted_prob_off = dst_row * stride_permuted_probs_token
                tl.store(permuted_probs_ptr + permuted_prob_off, prob)
            if prob == 0.0:
                # for routing_map padding
                # dst_row != -1 and prob == 0.0 means that this slot is padded
                tl.store(output_ptr + output_off, 0.0, mask=mask)
            else:
                tl.store(output_ptr + output_off, inp, mask=mask)
        else:
            tl.store(output_ptr + output_off, inp, mask=mask)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
# Let Triton pick BLOCK_SIZE (64 ... 4096, powers of two) per distinct hidden_size.
# NOTE(review): the RuntimeError fallback is presumably for environments where the
# autotuner cannot be constructed (e.g. no usable GPU driver at import time). In that
# case the bare @triton.jit kernel is kept, and a call site whose grid reads
# META["BLOCK_SIZE"] without passing BLOCK_SIZE would fail at launch — confirm intent.
try:
    _permute_kernel = triton.autotune(
        configs=[triton.Config({"BLOCK_SIZE": 1 << shift}) for shift in range(6, 13)],
        key=["hidden_size"],
    )(_permute_kernel)
except RuntimeError:
    pass
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def permute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor | None,
    scale: torch.Tensor | None,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
    scale_hidden_dim: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Permute the input tensor based on the row_id_map.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    row_id_map: torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    probs: torch.Tensor | None
        The probabilities of the input tensor. If it is not None, it will be permuted.
    scale: torch.Tensor | None
        The scale of the input tensor. If it is not None, it will be permuted.
    num_tokens: int
        Number of tokens in the input tensor.
    num_experts: int
        Number of experts in the input tensor.
    num_out_tokens: int
        Number of tokens in the permuted tensor.
    hidden_size: int
        Hidden size of the input tensor.
    scale_hidden_dim: int
        Hidden size of the scale tensor.

    Returns
    -------
    tuple
        ``(output, permuted_scale, permuted_probs)``:

        - ``output``: shape `[num_out_tokens, hidden_size]`, dtype of ``inp``.
        - ``permuted_scale``: shape `[num_out_tokens, scale_hidden_dim]`, or None
          when ``scale`` is None.
        - ``permuted_probs``: shape `[num_out_tokens]`, or None when ``probs`` is None.

        All outputs are allocated on ``inp``'s device.
    """
    # Allocate next to the input rather than on the *current* CUDA device, which may differ.
    device = inp.device
    output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device=device)

    permuted_probs = None
    if probs is not None:
        permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device=device)

    permuted_scale = None
    if scale is not None:
        permuted_scale = torch.empty(
            (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device=device
        )

    # BLOCK_SIZE is chosen by the autotuner and exposed through META.
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _permute_kernel[grid](
        inp,
        output,
        row_id_map,
        probs,
        scale,
        permuted_probs,
        permuted_scale,
        num_experts,
        hidden_size,
        scale_hidden_dim,
        row_id_map.stride(0),
        row_id_map.stride(1),
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if probs is not None else None,
        probs.stride(1) if probs is not None else None,
        scale.stride(0) if scale is not None else None,
        scale.stride(1) if scale is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        permuted_scale.stride(0) if permuted_scale is not None else None,
        permuted_scale.stride(1) if permuted_scale is not None else None,
        PERMUTE_PROBS=probs is not None,
        PERMUTE_SCALE=scale is not None,
    )
    return output, permuted_scale, permuted_probs
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
@triton.jit
def _unpermute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    merging_probs_ptr,
    permuted_probs_ptr,
    unpermuted_probs_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_permuted_probs_token,
    stride_unpermuted_probs_token,
    stride_unpermuted_probs_expert,
    # metas
    PROBS_LOAD_WIDTH: tl.constexpr,
    WITH_MERGING_PROBS: tl.constexpr,
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Gather and reduce the permuted rows of each token back into one output row.

    Program ``(pid_t, pid_h)`` handles token ``pid_t`` and hidden columns
    ``[pid_h * BLOCK_SIZE, (pid_h + 1) * BLOCK_SIZE)``. For each of the token's
    ``n_routed`` source rows (taken from the dense row-id map), the input chunk
    is accumulated in float32. When ``WITH_MERGING_PROBS`` is set, each chunk is
    first weighted by ``merging_probs[token, expert]``. The sum is cast back to
    the input element type and stored to ``output[token]``.

    With ``PERMUTE_PROBS``, the ``pid_h == 0`` program also rebuilds
    ``unpermuted_probs[token, :]``. It writes ``permuted_probs[src_row]`` at each
    routed expert and 0.0 at every other expert. ``PROBS_LOAD_WIDTH`` must be a
    power of two ``>= num_experts``.
    """
    data_type = input_ptr.dtype.element_ty
    compute_type = tl.float32

    pid_t = tl.program_id(0)
    pid_h = tl.program_id(1)
    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = current_offset < hidden_size
    if PERMUTE_PROBS:
        # write 0.0 to probs_grad that are not routed
        if pid_h == 0:
            map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
            unpermuted_prob_off = (
                pid_t * stride_unpermuted_probs_token
                + stride_unpermuted_probs_expert * map_load_off
            )
            tl.store(
                unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts
            )
        # The vectorized zero-fill above and the scalar per-expert stores in the loop below
        # target the same addresses but may be issued by different threads of this program.
        # Without a barrier their global-memory order is unspecified, so a late zero could
        # overwrite a routed probability. pid_h is uniform per program, so every thread
        # reaches this barrier.
        tl.debug_barrier()
    accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
    # Number of experts this token was routed to (last column of the row-id map).
    n_routed = tl.load(
        row_id_map_ptr
        + pid_t * stride_row_id_map_token
        + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        # Column idx holds the idx-th permuted row of this token.
        src_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
        )
        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
        inp = tl.load(input_ptr + input_off, mask=mask)
        inp = inp.to(compute_type)
        if WITH_MERGING_PROBS:
            # Column num_experts + idx holds the expert index matching src_row.
            expert_idx = tl.load(
                row_id_map_ptr
                + pid_t * stride_row_id_map_token
                + (num_experts + idx) * stride_row_id_map_expert
            )
            merging_prob_off = (
                pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
            )
            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
            inp *= merging_prob
        accumulator += inp
        if PERMUTE_PROBS:
            if pid_h == 0:
                expert_idx = tl.load(
                    row_id_map_ptr
                    + pid_t * stride_row_id_map_token
                    + (num_experts + idx) * stride_row_id_map_expert
                )
                unpermuted_prob_off = (
                    pid_t * stride_unpermuted_probs_token
                    + expert_idx * stride_unpermuted_probs_expert
                )
                permuted_prob_off = src_row * stride_permuted_probs_token
                prob = tl.load(permuted_probs_ptr + permuted_prob_off)
                tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
    accumulator = accumulator.to(data_type)
    output_off = pid_t * stride_output_token + current_offset * stride_output_hidden
    tl.store(output_ptr + output_off, accumulator, mask=mask)
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
# Let Triton pick BLOCK_SIZE (64 ... 4096, powers of two) per distinct hidden_size.
# NOTE(review): the RuntimeError fallback is presumably for environments where the
# autotuner cannot be constructed (e.g. no usable GPU driver at import time). In that
# case the bare @triton.jit kernel is kept, and a call site whose grid reads
# META["BLOCK_SIZE"] without passing BLOCK_SIZE would fail at launch — confirm intent.
try:
    _unpermute_kernel = triton.autotune(
        configs=[triton.Config({"BLOCK_SIZE": 1 << shift}) for shift in range(6, 13)],
        key=["hidden_size"],
    )(_unpermute_kernel)
except RuntimeError:
    pass
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
def unpermute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: torch.Tensor | None,
    permuted_probs: torch.Tensor | None,
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Unpermute the input tensor based on the row_id_map.

    For every token, the permuted rows routed to it are accumulated into a single
    output row. When `merging_probs` is given, each row is first scaled by its
    merging probability.

    Parameters
    ----------
    inp: torch.Tensor
        Permuted input tensor of shape `[num_out_tokens, hidden_size]`.
    row_id_map: torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    merging_probs: torch.Tensor | None
        Merging probabilities indexed by `(token, expert)`. If not None, they are used
        as weights to reduce the unpermuted tokens.
    permuted_probs: torch.Tensor | None
        Probabilities in permuted order (one per permuted row). If not None, they are
        scattered back to `[num_tokens, num_experts]`.
    num_tokens: int
        Number of tokens in the unpermuted (output) tensor.
    num_experts: int
        Number of experts.
    hidden_size: int
        Hidden size of the input and output tensors.

    Returns
    -------
    output: torch.Tensor
        Tensor of shape `[num_tokens, hidden_size]` with the dtype of `inp`.
    unpermuted_probs: torch.Tensor | None
        Tensor of shape `[num_tokens, num_experts]` with the dtype of `permuted_probs`,
        or None when `permuted_probs` is None.
    """
    # Allocate next to the input instead of on the implicit "current" CUDA device,
    # which may differ from `inp.device` in multi-GPU processes.
    device = inp.device
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=device)
    if permuted_probs is not None:
        unpermuted_probs = torch.empty(
            (num_tokens, num_experts), dtype=permuted_probs.dtype, device=device
        )
    else:
        unpermuted_probs = None
    # One program per (token, hidden-dim chunk); BLOCK_SIZE is picked by the autotuner.
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _unpermute_kernel[grid](
        inp,
        output,
        row_id_map,
        merging_probs,
        permuted_probs,
        unpermuted_probs,
        num_experts,
        hidden_size,
        row_id_map.stride(0),
        row_id_map.stride(1),
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        merging_probs.stride(0) if merging_probs is not None else None,
        merging_probs.stride(1) if merging_probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
        unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
        PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
        WITH_MERGING_PROBS=merging_probs is not None,
        PERMUTE_PROBS=permuted_probs is not None,
    )
    return output, unpermuted_probs
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
class _moe_permute_mask_map(torch.autograd.Function):
|
|
634
|
+
"""functional Permute with mask router map"""
|
|
635
|
+
|
|
636
|
+
@staticmethod
|
|
637
|
+
def forward(
|
|
638
|
+
ctx,
|
|
639
|
+
inp: torch.Tensor,
|
|
640
|
+
routing_map: torch.Tensor,
|
|
641
|
+
num_out_tokens: int,
|
|
642
|
+
probs: torch.Tensor,
|
|
643
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
644
|
+
if not inp.numel():
|
|
645
|
+
ctx.probs = probs
|
|
646
|
+
return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device)
|
|
647
|
+
|
|
648
|
+
assert inp.is_cuda, "TransformerEngine needs CUDA."
|
|
649
|
+
assert routing_map.is_cuda, "TransformerEngine needs CUDA."
|
|
650
|
+
if probs is not None:
|
|
651
|
+
assert probs.is_cuda, "TransformerEngine needs CUDA."
|
|
652
|
+
|
|
653
|
+
assert inp.size(0) == routing_map.size(0), "Permute not possible"
|
|
654
|
+
num_tokens, hidden_size = inp.size()
|
|
655
|
+
num_experts = routing_map.size(1)
|
|
656
|
+
assert (
|
|
657
|
+
num_out_tokens is not None
|
|
658
|
+
), "num_out_tokens must be provided to the fused permute function."
|
|
659
|
+
|
|
660
|
+
row_id_map = make_row_id_map(routing_map, num_tokens, num_experts)
|
|
661
|
+
|
|
662
|
+
# todo torchao fp8
|
|
663
|
+
|
|
664
|
+
output, permuted_scale, permuted_probs = permute_with_mask_map(
|
|
665
|
+
inp,
|
|
666
|
+
row_id_map,
|
|
667
|
+
probs,
|
|
668
|
+
None,
|
|
669
|
+
num_tokens,
|
|
670
|
+
num_experts,
|
|
671
|
+
num_out_tokens,
|
|
672
|
+
hidden_size,
|
|
673
|
+
None,
|
|
674
|
+
)
|
|
675
|
+
|
|
676
|
+
ctx.save_for_backward(row_id_map)
|
|
677
|
+
ctx.num_experts = num_experts
|
|
678
|
+
ctx.num_tokens = num_tokens
|
|
679
|
+
ctx.hidden_size = hidden_size
|
|
680
|
+
return output, row_id_map, permuted_probs
|
|
681
|
+
|
|
682
|
+
@staticmethod
|
|
683
|
+
def backward(
|
|
684
|
+
ctx,
|
|
685
|
+
permuted_act_grad: torch.Tensor,
|
|
686
|
+
_,
|
|
687
|
+
permuted_probs_grad: torch.Tensor,
|
|
688
|
+
) -> tuple[torch.Tensor, ...]:
|
|
689
|
+
# pylint: disable=missing-function-docstring
|
|
690
|
+
if not permuted_act_grad.numel():
|
|
691
|
+
return permuted_act_grad, None, None, ctx.probs
|
|
692
|
+
|
|
693
|
+
act_grad = None
|
|
694
|
+
probs_grad = None
|
|
695
|
+
if ctx.needs_input_grad[0]:
|
|
696
|
+
(row_id_map,) = ctx.saved_tensors
|
|
697
|
+
act_grad, probs_grad = unpermute_with_mask_map(
|
|
698
|
+
permuted_act_grad,
|
|
699
|
+
row_id_map,
|
|
700
|
+
None,
|
|
701
|
+
permuted_probs_grad,
|
|
702
|
+
ctx.num_tokens,
|
|
703
|
+
ctx.num_experts,
|
|
704
|
+
ctx.hidden_size,
|
|
705
|
+
)
|
|
706
|
+
if not ctx.needs_input_grad[3]:
|
|
707
|
+
probs_grad = None
|
|
708
|
+
return act_grad, None, None, probs_grad
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
def moe_permute_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Permute tokens and their routing probabilities according to `routing_map`.

    Rows are regrouped so that all tokens assigned to the same expert become
    contiguous; `routing_map` marks which experts each token selected.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs: torch.Tensor
        Probabilities of shape `[num_tokens, num_experts]`; they are permuted together
        with the tokens according to `routing_map`.
    routing_map: torch.Tensor
        Token-to-expert mask of shape `[num_tokens, num_experts]`: 1 means the token is
        routed to that expert, 0 means it is not.
    num_out_tokens: int, default = -1
        The effective output token count, i.e. the number of tokens not dropped.
        `-1` is meant to signal that no tokens are dropped.
        NOTE(review): the value is forwarded unchanged; the autograd function only
        rejects None, so confirm the fused permute handles -1.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        `(permuted_tokens, permuted_probs, row_id_map)`. Pass `row_id_map` to the
        unpermute function to restore the original token order.
    """
    results = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, probs)
    # The autograd function yields row_id_map second; this API returns it last.
    permuted_tokens, row_id_map, permuted_probs = results
    return (permuted_tokens, permuted_probs, row_id_map)
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
# Backward of the merging-probs unpermute. The program id indexes one token row.
# For every expert the token was routed to, it:
#   * writes fwd_input_grad[dst_row] = fwd_output_grad[token] * merging_prob[token, expert]
#   * writes merging_probs_grad[token, expert] = sum_h fwd_input[dst_row, h] * fwd_output_grad[token, h]
# Entries of merging_probs_grad for experts the token did not use are zero-filled.
# row_id_map layout per token (as read below): columns [0, n_routed) hold destination
# rows, columns [num_experts, num_experts + n_routed) hold expert indices, and column
# 2 * num_experts holds n_routed.
@triton.jit
def _unpermute_bwd_with_merging_probs_kernel(
    # pointers
    fwd_output_grad_ptr,
    fwd_input_grad_ptr,
    fwd_input_ptr,
    merging_probs_ptr,
    merging_probs_grad_ptr,
    row_id_map_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_fwd_output_grad_token,
    stride_fwd_output_grad_hidden,
    stride_fwd_input_grad_token,
    stride_fwd_input_grad_hidden,
    stride_fwd_input_token,
    stride_fwd_input_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_merging_probs_grad_token,
    stride_merging_probs_grad_expert,
    # metas
    PROBS_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    data_type = fwd_output_grad_ptr.dtype.element_ty
    compute_type = tl.float32

    pid = tl.program_id(0)
    # Zero the whole probs-grad row first so experts not routed to get a zero gradient.
    map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
    token_probs_grad_off = (
        pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
    )
    tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
    n_routed = tl.load(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
        )
        expert_idx = tl.load(
            row_id_map_ptr
            + pid * stride_row_id_map_token
            + (num_experts + idx) * stride_row_id_map_expert
        )
        # The merging probability depends only on (token, expert): load it once per
        # routed expert instead of once per hidden-dim chunk.
        merging_prob_off = (
            pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
        )
        merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
        prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
        current_start = 0
        while current_start < hidden_size:
            current_offset = current_start + tl.arange(0, BLOCK_SIZE)
            mask = current_offset < hidden_size
            input_off = (
                pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden
            )
            # other=0.0: masked-off lanes feed the reduction below and must not be undefined.
            inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask, other=0.0)
            inp = inp.to(compute_type)
            output = inp * merging_prob
            output = output.to(data_type)
            output_off = (
                dst_row * stride_fwd_input_grad_token
                + current_offset * stride_fwd_input_grad_hidden
            )
            tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)

            fwd_input_off = (
                dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
            )
            fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask, other=0.0)
            prob_grad_accum += fwd_input.to(compute_type) * inp
            current_start += BLOCK_SIZE
        probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
        probs_grad_off = (
            pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
        )
        # NOTE(review): this overwrites an entry zero-filled above in the same program;
        # presumably ordered by the barrier implied by tl.sum — confirm.
        tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
try:
    # Autotune the merging-probs backward kernel over power-of-two BLOCK_SIZE
    # candidates; tuning is keyed on `hidden_size`.
    _unpermute_bwd_with_merging_probs_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": block_size})
            for block_size in (64, 128, 256, 512, 1024, 2048, 4096)
        ],
        key=["hidden_size"],
    )(_unpermute_bwd_with_merging_probs_kernel)
except RuntimeError:
    # Best effort: if `triton.autotune` raises, keep the plain JIT kernel.
    # NOTE(review): the launch site does not pass BLOCK_SIZE explicitly, so the
    # plain kernel would presumably fail at launch time — confirm this fallback is intended.
    pass
|
|
844
|
+
|
|
845
|
+
|
|
846
|
+
def unpermute_with_mask_map_bwd_with_merging_probs(
    fwd_output_grad: torch.Tensor,
    row_id_map: torch.Tensor,
    fwd_input: torch.Tensor,
    merging_probs: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Unpermute backward pass kernel with merging probs.

    Parameters
    ----------
    fwd_output_grad: torch.Tensor
        The gradient of the unpermute output, of shape `[num_tokens, hidden_size]`.
    row_id_map: torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    fwd_input: torch.Tensor
        The (permuted) input of the forward pass, of shape `[num_out_tokens, hidden_size]`.
    merging_probs: torch.Tensor
        The merging probabilities of shape `[num_tokens, num_experts]`.
    num_tokens: int
        Number of tokens in the unpermuted tensor (rows of `fwd_output_grad`).
    num_experts: int
        Number of experts.
    num_out_tokens: int
        Number of permuted rows (rows of `fwd_input` and of the returned activation grad).
    hidden_size: int
        Hidden size of the activations.

    Returns
    -------
    act_grad: torch.Tensor
        Gradient w.r.t. `fwd_input`, shape `[num_out_tokens, hidden_size]`.
    merging_probs_grad: torch.Tensor
        Gradient w.r.t. `merging_probs`, shape `[num_tokens, num_experts]`.
    """
    # Allocate next to the incoming gradient instead of on the implicit current CUDA device.
    device = fwd_output_grad.device
    act_grad = torch.empty(
        (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device=device
    )
    merging_probs_grad = torch.empty(
        (num_tokens, num_experts), dtype=merging_probs.dtype, device=device
    )
    # One program per token; the kernel loops over the token's routed experts.
    grid = (num_tokens,)
    _unpermute_bwd_with_merging_probs_kernel[grid](
        fwd_output_grad,
        act_grad,
        fwd_input,
        merging_probs,
        merging_probs_grad,
        row_id_map,
        num_experts,
        hidden_size,
        row_id_map.stride(0),
        row_id_map.stride(1),
        fwd_output_grad.stride(0),
        fwd_output_grad.stride(1),
        act_grad.stride(0),
        act_grad.stride(1),
        fwd_input.stride(0),
        fwd_input.stride(1),
        merging_probs.stride(0),
        merging_probs.stride(1),
        merging_probs_grad.stride(0),
        merging_probs_grad.stride(1),
        PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
    )
    return act_grad, merging_probs_grad
|
|
909
|
+
|
|
910
|
+
|
|
911
|
+
class _moe_unpermute_mask_map(torch.autograd.Function):
|
|
912
|
+
"""functional Unpermute with mask router map"""
|
|
913
|
+
|
|
914
|
+
@staticmethod
|
|
915
|
+
def forward(
|
|
916
|
+
ctx,
|
|
917
|
+
inp: torch.Tensor,
|
|
918
|
+
row_id_map: torch.Tensor,
|
|
919
|
+
merging_probs: torch.Tensor | None,
|
|
920
|
+
restore_shape: torch.Size | None,
|
|
921
|
+
) -> torch.Tensor:
|
|
922
|
+
# pylint: disable=missing-function-docstring
|
|
923
|
+
if not inp.numel():
|
|
924
|
+
ctx.merging_probs = merging_probs
|
|
925
|
+
return inp
|
|
926
|
+
|
|
927
|
+
if restore_shape is None:
|
|
928
|
+
restore_shape = inp.shape
|
|
929
|
+
num_tokens, hidden_size = restore_shape
|
|
930
|
+
num_experts = (row_id_map.size(1) - 1) // 2
|
|
931
|
+
|
|
932
|
+
with_probs = merging_probs is not None
|
|
933
|
+
if with_probs:
|
|
934
|
+
assert merging_probs.is_cuda, "TransformerEngine needs CUDA."
|
|
935
|
+
|
|
936
|
+
# Device check
|
|
937
|
+
assert inp.is_cuda, "TransformerEngine needs CUDA."
|
|
938
|
+
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
|
|
939
|
+
|
|
940
|
+
unpermuted_output, _ = unpermute_with_mask_map(
|
|
941
|
+
inp,
|
|
942
|
+
row_id_map,
|
|
943
|
+
merging_probs,
|
|
944
|
+
None,
|
|
945
|
+
num_tokens,
|
|
946
|
+
num_experts,
|
|
947
|
+
hidden_size,
|
|
948
|
+
)
|
|
949
|
+
|
|
950
|
+
if with_probs:
|
|
951
|
+
ctx.save_for_backward(inp, row_id_map, merging_probs)
|
|
952
|
+
else:
|
|
953
|
+
ctx.save_for_backward(row_id_map)
|
|
954
|
+
ctx.num_experts = num_experts
|
|
955
|
+
ctx.num_tokens = num_tokens
|
|
956
|
+
ctx.num_permuted_tokens = inp.size(0)
|
|
957
|
+
ctx.hidden_size = hidden_size
|
|
958
|
+
ctx.with_probs = with_probs
|
|
959
|
+
return unpermuted_output
|
|
960
|
+
|
|
961
|
+
@staticmethod
|
|
962
|
+
def backward(ctx, unpermuted_act_grad):
|
|
963
|
+
# pylint: disable=missing-function-docstring
|
|
964
|
+
if not unpermuted_act_grad.numel():
|
|
965
|
+
return unpermuted_act_grad, None, ctx.merging_probs, None
|
|
966
|
+
|
|
967
|
+
act_grad = None
|
|
968
|
+
probs_grad = None
|
|
969
|
+
if ctx.needs_input_grad[0]:
|
|
970
|
+
if ctx.with_probs:
|
|
971
|
+
fwd_input, row_id_map, merging_probs = ctx.saved_tensors
|
|
972
|
+
else:
|
|
973
|
+
(row_id_map,) = ctx.saved_tensors
|
|
974
|
+
|
|
975
|
+
if ctx.with_probs:
|
|
976
|
+
act_grad, probs_grad = (
|
|
977
|
+
unpermute_with_mask_map_bwd_with_merging_probs(
|
|
978
|
+
unpermuted_act_grad,
|
|
979
|
+
row_id_map,
|
|
980
|
+
fwd_input,
|
|
981
|
+
merging_probs,
|
|
982
|
+
ctx.num_tokens,
|
|
983
|
+
ctx.num_experts,
|
|
984
|
+
ctx.num_permuted_tokens,
|
|
985
|
+
ctx.hidden_size,
|
|
986
|
+
)
|
|
987
|
+
)
|
|
988
|
+
else:
|
|
989
|
+
act_grad, permuted_scale, _ = permute_with_mask_map(
|
|
990
|
+
unpermuted_act_grad,
|
|
991
|
+
row_id_map,
|
|
992
|
+
None,
|
|
993
|
+
None,
|
|
994
|
+
ctx.num_tokens,
|
|
995
|
+
ctx.num_experts,
|
|
996
|
+
ctx.num_permuted_tokens,
|
|
997
|
+
ctx.hidden_size,
|
|
998
|
+
None,
|
|
999
|
+
)
|
|
1000
|
+
|
|
1001
|
+
if not ctx.needs_input_grad[2]:
|
|
1002
|
+
probs_grad = None
|
|
1003
|
+
return act_grad, None, probs_grad, None
|
|
1004
|
+
|
|
1005
|
+
|
|
1006
|
+
def moe_unpermute_mask(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: torch.Tensor | None = None,
    restore_shape: torch.Size | None = None,
) -> torch.Tensor:
    """
    Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
    corresponding probabilities.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor with permuted tokens of shape `[num_permuted_tokens, hidden_size]`
        to be unpermuted.
    row_id_map: torch.Tensor
        The mapping table produced by the permute step (the last output of
        `moe_permute_with_probs`), used to route permuted rows back to their tokens.
    merging_probs: torch.Tensor | None, default = None
        Probabilities of shape `[num_tokens, num_experts]`. If provided, each permuted
        row is scaled by its probability before the rows of a token are summed.
        If None, the rows of a token are merged by plain accumulation.
    restore_shape: torch.Size | None, default = None
        The output shape `[num_tokens, hidden_size]` after the unpermute operation.
        If None, the shape of `inp` is used.

    Returns
    -------
    torch.Tensor
        The unpermuted tensor of shape `restore_shape`. An empty `inp` is returned
        unchanged.
    """
    return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape)
|