vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -0,0 +1,479 @@
+"""Fused MoE kernel."""
+import functools
+import json
+import os
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.utils import is_hip
+
+logger = init_logger(__name__)
+
+
+@triton.jit
+def fused_moe_kernel(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    a_scale_ptr,
+    b_scale_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Matrix dimensions
+    N,
+    K,
+    EM,
+    num_valid_tokens,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+    use_fp8: tl.constexpr,
+):
+    """
+    Implements the fused computation for a Mixture of Experts (MOE) using
+    token and expert matrices.
+
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+        the number of experts, K is the input feature dimension, and N is
+        the output feature dimension.
+    - C: The output cache tensor with shape (M, topk, N), where M is the
+        total number of tokens post padding, topk is the number of times
+        each token is repeated, and N is the output feature dimension.
+    - sorted_token_ids: A tensor containing the sorted indices of tokens,
+        repeated topk times and arranged by the expert index they are
+        assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+        block. It determines which expert matrix from B should be used for
+        each block in A.
+    This kernel performs the multiplication of a token by its corresponding
+    expert matrix as determined by `expert_ids`. The sorting of
+    `sorted_token_ids` by expert index and padding ensures divisibility by
+    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+    multiplication across different blocks processed by the same expert.
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    token_mask = offs_token < num_valid_tokens
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
+                      offs_k[None, :] * stride_ak)
+
+    off_experts = tl.load(expert_ids_ptr + pid_m)
+    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
+                                                offs_bn[None, :] * stride_bn)
+
+    if use_fp8:
+        a_scale = tl.load(a_scale_ptr)
+        b_scale = tl.load(b_scale_ptr + off_experts)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+        a = tl.load(a_ptrs,
+                    mask=token_mask[:, None] &
+                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                    other=0.0)
+        b = tl.load(b_ptrs,
+                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+                    other=0.0)
+        # We accumulate along the K dimension.
+        if use_fp8:
+            accumulator = tl.dot(a, b, acc=accumulator)
+        else:
+            accumulator += tl.dot(a, b)
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token,
+                             mask=token_mask,
+                             other=0)
+        accumulator = accumulator * moe_weight[:, None]
+
+    if use_fp8:
+        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+    else:
+        accumulator = accumulator.to(compute_type)
+    # -----------------------------------------------------------
+    # Write back the block of the output
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
+        None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+def moe_align_block_size(
+        topk_ids: torch.Tensor, block_size: int,
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Aligns the token distribution across experts to be compatible with block
+    size for matrix multiplication.
+
+    Parameters:
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+        top-k expert indices for each token.
+    - block_size: The block size used in block matrix multiplication.
+    - num_experts: The total number of experts.
+
+    Returns:
+    - sorted_token_ids: A tensor containing the sorted token indices according
+        to their allocated expert.
+    - expert_ids: A tensor indicating the assigned expert index for each block.
+    - num_tokens_post_padded: The total number of tokens after padding,
+        ensuring divisibility by block_size.
+
+    This function pads the number of tokens that each expert needs to process
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
+    Example:
+    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
+    block_size = 4, and num_experts = 4:
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+        with each expert needing to process 3 tokens.
+    - As block_size is 4, we pad 1 token for each expert.
+    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+    - Then append padding tokens [12, 12, 12, 12] for each block.
+    - After sorting by expert index, we obtain token_ids
+        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+        Tokens 12 are non-existent (padding) and are ignored in
+        the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+        by block_size for proper block matrix operations.
+    """
+    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    sorted_ids = torch.empty((max_num_tokens_padded, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
+    sorted_ids.fill_(topk_ids.numel())
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty((max_num_m_blocks, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
+    num_tokens_post_pad = torch.empty((1),
+                                      dtype=torch.int32,
+                                      device=topk_ids.device)
+    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
+                             expert_ids, num_tokens_post_pad)
+    return sorted_ids, expert_ids, num_tokens_post_pad
+
+
+def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
+                            A_scale: Optional[torch.Tensor],
+                            B_scale: Optional[torch.Tensor],
+                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                            sorted_token_ids: torch.Tensor,
+                            expert_ids: torch.Tensor,
+                            num_tokens_post_padded: torch.Tensor,
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any], compute_type: tl.dtype,
+                            use_fp8: bool) -> None:
+    assert topk_weights.stride(1) == 1
+    assert sorted_token_ids.stride(0) == 1
+
+    if not use_fp8:
+        assert A_scale is None
+        assert B_scale is None
+    else:
+        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        assert B_scale is not None
+
+    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
+        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+
+    fused_moe_kernel[grid](
+        A,
+        B,
+        C,
+        A_scale,
+        B_scale,
+        topk_weights,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        B.shape[1],
+        B.shape[2],
+        sorted_token_ids.shape[0],
+        topk_ids.numel(),
+        A.stride(0),
+        A.stride(1),
+        B.stride(0),
+        B.stride(2),
+        B.stride(1),
+        C.stride(1),
+        C.stride(2),
+        MUL_ROUTED_WEIGHT=mul_routed_weight,
+        top_k=top_k,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+        **config,
+    )
+
+
+def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
+    device_name = torch.cuda.get_device_name().replace(" ", "_")
+    dtype_selector = "" if not dtype else f",dtype={dtype}"
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
+
+
+@functools.lru_cache
+def get_moe_configs(E: int, N: int,
+                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the fused MoE kernel.
+
+    The return value will be a dictionary that maps an irregular grid of
+    batch sizes to configurations of the fused_moe kernel. To evaluate the
+    kernel on a given batch size bs, the closest batch size in the grid should
+    be picked and the associated configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs
+    # directory
+    json_file_name = get_config_file_name(E, N, dtype)
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info("Using configuration from %s for MoE layer.",
+                        config_file_path)
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default
+    # configuration
+    return None
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - inplace (bool): If True, perform the operation in-place.
+        Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
+    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
+    M, _ = hidden_states.shape
+    E, N, _ = w1.shape
+
+    if is_hip():
+        # The MoE kernels are not yet supported on ROCm.
+        routing_weights = torch.softmax(gating_output,
+                                        dim=-1,
+                                        dtype=torch.float32)
+        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
+    else:
+        import vllm._moe_C as moe_kernels
+
+        topk_weights = torch.empty(M,
+                                   topk,
+                                   dtype=torch.float32,
+                                   device=hidden_states.device)
+        topk_ids = torch.empty(M,
+                               topk,
+                               dtype=torch.int32,
+                               device=hidden_states.device)
+        token_expert_indicies = torch.empty(M,
+                                            topk,
+                                            dtype=torch.int32,
+                                            device=hidden_states.device)
+        moe_kernels.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),  # TODO(woosuk): Optimize this.
+        )
+        del token_expert_indicies  # Not used. Will be used in the future.
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2],
+                                  "float8" if use_fp8 else None)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32,
+                'GROUP_SIZE_M': 8
+            }
+
+            if M <= E:
+                config = {
+                    'BLOCK_SIZE_M': 16,
+                    'BLOCK_SIZE_N': 32,
+                    'BLOCK_SIZE_K': 64,
+                    'GROUP_SIZE_M': 1
+                }
+
+    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        topk_ids, config['BLOCK_SIZE_M'], E)
+    compute_type = (tl.bfloat16
+                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
+
+    invoke_fused_moe_kernel(hidden_states,
+                            w1,
+                            intermediate_cache1,
+                            a1_scale,
+                            w1_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            False,
+                            topk_ids.shape[1],
+                            config,
+                            compute_type=compute_type,
+                            use_fp8=use_fp8)
+
+    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+    invoke_fused_moe_kernel(intermediate_cache2,
+                            w2,
+                            intermediate_cache3,
+                            a2_scale,
+                            w2_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            True,
+                            1,
+                            config,
+                            compute_type=compute_type,
+                            use_fp8=use_fp8)
+
+    if inplace:
+        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                         dim=1,
+                         out=hidden_states)
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                     dim=1)
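
The two kernel invocations above implement a SiLU-gated expert MLP: w1 fuses the gate and up projections (so its output width is twice the intermediate size), ops.silu_and_mul gates the first half of that output with the second half, and w2 projects back to the hidden size. The sketch below is not part of the wheel; it is a minimal smoke test of the fused_moe entry point, assuming a CUDA device where Triton and the compiled vllm._moe_C extension are importable (this pure-Python NPU wheel does not ship that extension), with made-up shapes and tensor names.

# Illustrative only -- not part of the diff.
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

num_experts, hidden, inter, num_tokens, topk = 8, 1024, 2048, 4, 2
hidden_states = torch.randn(num_tokens, hidden, dtype=torch.float16, device="cuda")
# w1 stacks gate and up projections per expert, hence the 2 * inter rows;
# w2 maps the gated intermediate activation back to the hidden size.
w1 = torch.randn(num_experts, 2 * inter, hidden, dtype=torch.float16, device="cuda")
w2 = torch.randn(num_experts, hidden, inter, dtype=torch.float16, device="cuda")
gating_output = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")

out = fused_moe(hidden_states, w1, w2, gating_output, topk=topk, renormalize=True)
print(out.shape)  # torch.Size([4, 1024])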
vllm/model_executor/layers/layernorm.py
@@ -0,0 +1,71 @@
+"""Custom normalization layers."""
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from vllm import _custom_ops as ops
+
+
+class RMSNorm(nn.Module):
+    """Root mean square normalization.
+
+    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
+    Refer to https://arxiv.org/abs/1910.07467
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def _forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            ops.fused_add_rms_norm(
+                x,
+                residual,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return x, residual
+        out = torch.empty_like(x)
+        ops.rms_norm(
+            out,
+            x,
+            self.weight.data,
+            self.variance_epsilon,
+        )
+        return out
+
+    def extra_repr(self) -> str:
+        s = f"hidden_size={self.weight.data.size(0)}"
+        s += f", eps={self.variance_epsilon}"
+        return s
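
RMSNorm keeps a PyTorch-native _forward next to the custom-op forward, which is convenient for checking the fused kernels against a reference. A small sketch follows; it is not part of the wheel and assumes a build where vllm._custom_ops is usable on the target device, with illustrative shapes.

# Illustrative only -- compares the fused path with the native reference.
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

norm = RMSNorm(hidden_size=1024).to(dtype=torch.float16, device="cuda")
x = torch.randn(2, 8, 1024, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)

ref_out, ref_res = norm._forward(x.clone(), residual.clone())  # pure PyTorch
out, res = norm(x.clone(), residual.clone())  # fused_add_rms_norm, updates x and residual in place
print(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2))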