sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -27,7 +27,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
     attn_tp_all_gather,
+    attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
@@ -111,7 +113,8 @@ class LogitsMetadata:
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
-
+    # The gather mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None
     # for padding
     padded_static_len: int = -1
 
@@ -163,12 +166,12 @@ class LogitsMetadata:
             forward_batch_gathered_buffer=forward_batch.gathered_buffer,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.SUM_LEN,
         )
 
-    def compute_dp_attention_metadata(self
-
-
-        return
+    def compute_dp_attention_metadata(self):
+        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
+        # we may use a smaller buffer in draft extend.
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
@@ -179,18 +182,9 @@ class LogitsMetadata:
         else:
             dp_local_start_pos = cumtokens[dp_rank - 1]
         dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
-        gathered_buffer = torch.zeros(
-            (
-                sum(self.global_num_tokens_for_logprob_cpu),
-                hidden_states.shape[1],
-            ),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
 
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
-        self.gathered_buffer = gathered_buffer
 
 
 class LogitsProcessor(nn.Module):
@@ -434,7 +428,7 @@ class LogitsProcessor(nn.Module):
         guarantee the given hidden_states follow this constraint.
         """
         if self.do_tensor_parallel_all_gather_dp_attn:
-            logits_metadata.compute_dp_attention_metadata(
+            logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
                 torch.empty_like(logits_metadata.gathered_buffer),
                 hidden_states,
@@ -463,15 +457,31 @@ class LogitsProcessor(nn.Module):
 
         if self.do_tensor_parallel_all_gather:
             if self.use_attn_tp_group:
-
-
-
-
-
-
-
-
-
+                if self.config.vocab_size % self.attn_tp_size == 0:
+                    global_logits = torch.empty(
+                        (
+                            self.attn_tp_size,
+                            logits.shape[0],
+                            self.config.vocab_size // self.attn_tp_size,
+                        ),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    attn_tp_all_gather_into_tensor(global_logits, logits)
+                    global_logits = global_logits.permute(1, 0, 2).reshape(
+                        logits.shape[0], self.config.vocab_size
+                    )
+                else:
+                    global_logits = torch.empty(
+                        (self.config.vocab_size, logits.shape[0]),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    global_logits = global_logits.T
+                    attn_tp_all_gather(
+                        list(global_logits.tensor_split(self.attn_tp_size, dim=-1)),
+                        logits,
+                    )
                 logits = global_logits
             else:
                 logits = tensor_model_parallel_all_gather(logits)
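Note on the two branches added above (these LogitsMetadata/LogitsProcessor hunks correspond to the sglang/srt/layers/logits_processor.py entry in the file list): when vocab_size divides evenly by the attention-TP size, the new fast path gathers each rank's logit shard into a contiguous (attn_tp_size, tokens, vocab // attn_tp_size) buffer with attn_tp_all_gather_into_tensor and then permutes and reshapes back to (tokens, vocab); otherwise it falls back to the transpose-and-gather-into-views path. Below is a minimal single-process sketch of that reshaping pattern, with torch.stack standing in for the collective; names and sizes are illustrative, not sglang APIs.

import torch

# Stand-in for the sharded logits each attention-TP rank holds:
# every rank has the full token dimension but only vocab_size // tp_size columns.
tp_size, num_tokens, vocab_size = 4, 3, 32
full = torch.randn(num_tokens, vocab_size)
shards = list(full.chunk(tp_size, dim=-1))   # what each rank computes locally

# all_gather_into_tensor would fill a (tp_size, num_tokens, vocab_size // tp_size)
# buffer with rank r's shard at index r; stacking reproduces that layout here.
gathered = torch.stack(shards, dim=0)

# The permute + reshape from the new branch turns the rank-major layout back into
# token-major logits with the vocabulary dimension stitched together in order.
logits = gathered.permute(1, 0, 2).reshape(num_tokens, vocab_size)

assert torch.equal(logits, full)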
@@ -236,7 +236,8 @@ def pre_reorder_triton_kernel(
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
 
-    src_idx = tl.program_id(0)
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     src_ptr = input_ptr + src_idx * hidden_size
@@ -255,7 +256,8 @@ def pre_reorder_triton_kernel(
         else:
             scale = 1.0
 
-        dst_idx = tl.load(src2dst_ptr + idx)
+        dst_idx_int32 = tl.load(src2dst_ptr + idx)
+        dst_idx = dst_idx_int32.to(tl.int64)
         dst_ptr = gateup_input_ptr + dst_idx * hidden_size
         for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
             offset = start_offset + vec
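The two pre_reorder_triton_kernel hunks above (matching the sglang/srt/layers/moe/ep_moe/kernels.py entry in the file list) widen the source and destination row indices to int64 before they are multiplied by hidden_size. The motivation is plain index arithmetic: with 32-bit indices, the element offset row * hidden_size wraps once it exceeds 2**31 - 1, which large token-count * topk workloads can reach. A small back-of-the-envelope sketch (ordinary Python, hypothetical sizes):

# Hypothetical sizes, just to show when a 32-bit element offset would overflow.
hidden_size = 7168                      # e.g. a large MoE hidden dimension
int32_max = 2**31 - 1

# Largest row index whose offset row * hidden_size still fits in int32.
max_safe_rows = int32_max // hidden_size
print(max_safe_rows)                    # 299593 -> large batch * topk counts exceed this

# Widening the index first keeps the multiplication in 64-bit, as the kernel now does
# via idx.to(tl.int64); Python ints do not overflow, so this just illustrates the bound.
row = 1_000_000
offset64 = row * hidden_size
assert offset64 > int32_max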
@@ -1,17 +1,13 @@
 import logging
-from typing import
+from typing import List, Optional, Tuple
 
-import einops
 import torch
-from torch.nn import Module
 
-from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
-from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -27,22 +23,20 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton import
-from sglang.srt.layers.moe.
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import
+from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
-    scaled_fp8_quant,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.
+from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -53,7 +47,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
-    set_weight_attrs,
 )
 
 _is_hip = is_hip()
@@ -61,14 +54,11 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-if not _is_npu:
+if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
     from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 
-if _is_hip:
-    from vllm._custom_ops import scaled_fp8_quant
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
@@ -165,16 +155,9 @@ class EPMoE(torch.nn.Module):
         intermediate_size: int,
         layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
-        correction_bias: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
         use_per_token_if_dynamic: bool = True,
@@ -192,24 +175,12 @@
         self.layer_id = layer_id
         self.num_experts = num_experts
         assert self.num_experts % self.tp_size == 0
-        assert (
-            num_fused_shared_experts == 0
-        ), "num_fused_shared_experts is not supported in EP"
-        self.num_fused_shared_experts = num_fused_shared_experts
         self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
 
         self.top_k = top_k
         self.intermediate_size = intermediate_size
-        self.renormalize = renormalize
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-            self.num_expert_group = num_expert_group
-            self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.custom_routing_function = custom_routing_function
         self.activation = activation
         self.routed_scaling_factor = routed_scaling_factor
         self.use_per_token_if_dynamic = use_per_token_if_dynamic
@@ -314,33 +285,24 @@ class EPMoE(torch.nn.Module):
         )
         return (local_num_experts, expert_map)
 
-    def forward(self, hidden_states: torch.Tensor,
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
-            return self.forward_deepgemm(hidden_states,
+            return self.forward_deepgemm(hidden_states, topk_output)
         else:
-            return self.forward_normal(hidden_states,
+            return self.forward_normal(hidden_states, topk_output)
 
     def forward_deepgemm(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
-
-
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            correction_bias=self.correction_bias,
-            custom_routing_function=self.custom_routing_function,
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
+
+        topk_weights, topk_ids, _ = topk_output
 
         if not self.use_block_quant:
             # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
@@ -472,8 +434,10 @@ class EPMoE(torch.nn.Module):
         )
         return output
 
-    def forward_normal(self, hidden_states: torch.Tensor,
+    def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.quant_method is not None
+        topk_weights, topk_ids, _ = topk_output
+
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
@@ -484,23 +448,6 @@ class EPMoE(torch.nn.Module):
             use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )
 
-        topk_weights, topk_ids = select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            correction_bias=self.correction_bias,
-            custom_routing_function=self.custom_routing_function,
-            routed_scaling_factor=self.routed_scaling_factor,
-            expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
-                layer_id=self.layer_id,
-            ),
-        )
-
         if self.use_w4afp8:
             local_topk_ids = topk_ids
             if self.expert_map is not None:
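Taken together, the EPMoE hunks above (from the sglang/srt/layers/moe/ep_moe/layer.py entry in the file list) move expert routing out of the layer: the constructor drops its routing arguments, select_experts is no longer called inside forward_normal/forward_deepgemm, and forward now receives a precomputed TopKOutput that unpacks as topk_weights, topk_ids, _. Below is a rough sketch of the resulting call shape under stated assumptions: a plain NamedTuple stands in for sglang's TopKOutput (its third field is a guess), and a trivial softmax top-k replaces the real routing logic.

from typing import NamedTuple

import torch


class TopKOutputSketch(NamedTuple):
    # Stand-in for sglang's TopKOutput; the diff only shows that it unpacks into
    # (topk_weights, topk_ids, <something>), so the third field here is hypothetical.
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor


def select_topk(router_logits: torch.Tensor, top_k: int) -> TopKOutputSketch:
    # Minimal routing: softmax over experts, then keep the top_k experts per token.
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    return TopKOutputSketch(topk_weights, topk_ids, router_logits)


# New call shape per the diff: routing happens first, and its result is passed in,
# e.g. out = moe_layer(hidden_states, topk_output), instead of the layer computing
# top-k experts internally from raw router logits.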
@@ -904,324 +851,6 @@ class EPMoE(torch.nn.Module):
                 param_data[expert_id] = loaded_weight
 
 
-class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # scale
-        layer.register_parameter("w13_input_scale", None)
-        layer.register_parameter("w13_weight_scale", None)
-
-        ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-
-        w2_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_input_scale", w2_input_scale)
-        set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        w2_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-
-class Fp8EPMoEMethod(Fp8MoEMethod):
-    """MoE method for FP8.
-    Supports loading FP8 checkpoints with static weight scale and
-    dynamic/static activation scale.
-
-    Args:
-        quant_config: The quantization config.
-    """
-
-    def __init__(self, quant_config: Fp8Config):
-        self.quant_config = quant_config
-        self.block_quant = self.quant_config.weight_block_size is not None
-
-    def create_weights(
-        self,
-        layer: Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        tp_size = get_tensor_model_parallel_world_size()
-        if self.block_quant:
-            block_n, block_k = (
-                self.quant_config.weight_block_size[0],
-                self.quant_config.weight_block_size[1],
-            )
-            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by column parallel or enabling merged weights
-            if intermediate_size % block_n != 0:
-                raise ValueError(
-                    f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size} is not divisible by "
-                    f"weight quantization block_n = {block_n}."
-                )
-            if tp_size > 1:
-                # Required by row parallel
-                if intermediate_size % block_k != 0:
-                    raise ValueError(
-                        f"The input_size of down's weight = "
-                        f"{intermediate_size} is not divisible by "
-                        f"weight quantization block_k = {block_k}."
-                    )
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        if self.block_quant:
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
-                    (hidden_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
-            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
-            assert self.quant_config.activation_scheme == "dynamic"
-        else:
-            # WEIGHT_SCALES
-            # Allocate 2 scales for w1 and w3 respectively.
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
-            if self.block_quant
-            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-        # If loading fp8 checkpoint, pass the weight loaders.
-        # If loading an fp16 checkpoint, do not (we will quantize in
-        # process_weights_after_loading()
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        # INPUT_SCALES
-        if self.quant_config.activation_scheme == "static":
-            if not self.quant_config.is_checkpoint_fp8_serialized:
-                raise ValueError(
-                    "Found static activation scheme for checkpoint that "
-                    "was not serialized fp8."
-                )
-
-            w13_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, extra_weight_attrs)
-
-            w2_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        else:
-            layer.w13_input_scale = None
-            layer.w2_input_scale = None
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-
-        # If checkpoint is fp16, quantize in place.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
-            layer.w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    layer.num_experts_per_partition,
-                    dtype=torch.float32,
-                    device=w13_weight.device,
-                ),
-                requires_grad=False,
-            )
-
-            for expert in range(layer.num_experts_per_partition):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            return
-
-        # If checkpoint is fp8, we need to handle that the
-        # MoE kernels require single activation scale and single weight
-        # scale for w13 per expert.
-        else:
-            if self.quant_config.activation_scheme == "static":
-                if layer.w13_input_scale is None or layer.w2_input_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None."
-                    )
-                layer.w13_weight_scale = torch.nn.Parameter(
-                    torch.max(layer.w13_weight_scale, dim=1).values,
-                    requires_grad=False,
-                )
-            if self.block_quant:
-                # If ROCm, normalize the weights and scales to e4m3fnuz
-                if _is_fp8_fnuz:
-                    # activation_scheme: dynamic
-                    w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=layer.w13_weight,
-                        weight_scale=layer.w13_weight_scale_inv,
-                        input_scale=None,
-                    )
-                    w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=layer.w2_weight,
-                        weight_scale=layer.w2_weight_scale_inv,
-                        input_scale=None,
-                    )
-                    # Reset the parameter
-                    layer.w13_weight = torch.nn.Parameter(
-                        w13_weight, requires_grad=False
-                    )
-                    layer.w13_weight_scale_inv = torch.nn.Parameter(
-                        w13_weight_scale, requires_grad=False
-                    )
-                    layer.w13_input_scale = None
-                    layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                    layer.w2_weight_scale_inv = torch.nn.Parameter(
-                        w2_weight_scale, requires_grad=False
-                    )
-                    layer.w2_input_scale = None
-                if _use_aiter:
-                    layer.w13_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w13_weight.data, (16, 16)),
-                        requires_grad=False,
-                    )
-                    layer.w2_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w2_weight.data, (16, 16)),
-                        requires_grad=False,
-                    )
-                return
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-
 class DeepEPMoE(EPMoE):
     """
     MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
@@ -1237,16 +866,9 @@ class DeepEPMoE(EPMoE):
         intermediate_size: int,
         layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
-        correction_bias: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
         deepep_mode: DeepEPMode = DeepEPMode.auto,
@@ -1258,20 +880,19 @@ class DeepEPMoE(EPMoE):
             intermediate_size=intermediate_size,
             layer_id=layer_id,
             params_dtype=params_dtype,
-            renormalize=renormalize,
-            use_grouped_topk=use_grouped_topk,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            topk_group=topk_group,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=prefix,
-            correction_bias=correction_bias,
-            custom_routing_function=custom_routing_function,
             activation=activation,
             routed_scaling_factor=routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+            assert self.use_fp8_w8a8, (
+                "DeepGEMM requires an fp8_w8a8 model; "
+                "alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
+            )
+
         if self.deepep_mode.enable_low_latency():
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|