sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear there.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -49,6 +49,8 @@ SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
 
 TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
 
+TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
+
 
 @dataclass
 class EagleDraftInput:
@@ -177,11 +179,24 @@ class EagleDraftInput:
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, None
 
-    def filter_batch(self, new_indices: torch.Tensor):
-        self.topk_p = self.topk_p[: len(new_indices)]
-        self.topk_index = self.topk_index[: len(new_indices)]
-        self.hidden_states = self.hidden_states[: len(new_indices)]
-        self.verified_id = self.verified_id[: len(new_indices)]
+    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+        if has_been_filtered:
+            # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
+            # therefore, we don't need to filter the batch again in scheduler
+            if len(new_indices) != len(self.topk_p):
+                logger.warning(
+                    f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
+                )
+            self.topk_p = self.topk_p[: len(new_indices)]
+            self.topk_index = self.topk_index[: len(new_indices)]
+            self.hidden_states = self.hidden_states[: len(new_indices)]
+            self.verified_id = self.verified_id[: len(new_indices)]
+        else:
+            # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
+            self.topk_p = self.topk_p[new_indices]
+            self.topk_index = self.topk_index[new_indices]
+            self.hidden_states = self.hidden_states[new_indices]
+            self.verified_id = self.verified_id[new_indices]
 
     def merge_batch(self, spec_info: EagleDraftInput):
         if self.hidden_states is None:
@@ -410,8 +425,15 @@ class EagleVerifyInput:
             logits=logits_output.next_token_logits, vocab_mask=vocab_mask
         )
 
-        # Sample tokens
-        if sampling_info.is_all_greedy:
+        # Sample tokens. Force greedy sampling on AMD
+        is_all_greedy = sampling_info.is_all_greedy
+        if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
+            logger.warning(
+                "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
+                "Falling back to greedy verification."
+            )
+
+        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
 
@@ -440,12 +462,13 @@ class EagleVerifyInput:
                     sampling_info.top_ks, self.draft_token_num, dim=0
                 ),
            )  # (bs * draft_token_num, vocab_size)
-            target_probs = top_p_renorm_prob(
-                target_probs,
-                torch.repeat_interleave(
-                    sampling_info.top_ps, self.draft_token_num, dim=0
-                ),
-            )
+            if not torch.all(sampling_info.top_ps == 1.0):
+                target_probs = top_p_renorm_prob(
+                    target_probs,
+                    torch.repeat_interleave(
+                        sampling_info.top_ps, self.draft_token_num, dim=0
+                    ),
+                )
             target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
 
             draft_probs = torch.zeros(
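The new has_been_filtered flag on EagleDraftInput.filter_batch distinguishes a batch that the verify step already shrank (default: truncate to the new length) from one that still needs explicit index selection. A minimal sketch of the two modes on a toy tensor, assuming a 4-request batch in which request 2 finished (illustrative only, not part of the diff):

    import torch

    topk_p = torch.arange(4).float()      # one row per running request
    unfinished = torch.tensor([0, 1, 3])  # indices of requests that did not finish

    # has_been_filtered=True (default): the batch was already filtered upstream,
    # so only truncate to the new length.
    truncated = topk_p[: len(unfinished)]  # keeps rows 0, 1, 2

    # has_been_filtered=False (e.g. after draft_extend): select the surviving rows
    # by index, which is what the new call site in eagle_worker.py does.
    selected = topk_p[unfinished]          # keeps rows 0, 1, 3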
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -9,7 +9,6 @@ from huggingface_hub import snapshot_download
 
 from sglang.srt.distributed import (
     GroupCoordinator,
-    get_tensor_model_parallel_world_size,
     get_tp_group,
     patch_tensor_parallel_group,
 )
@@ -92,7 +91,7 @@ class EAGLEWorker(TpModelWorker):
         )
         self.padded_static_len = -1
 
-        # Override context length
+        # Override the context length of the draft model to be the same as the target model.
         server_args.context_length = target_worker.model_runner.model_config.context_len
 
         # Do not capture cuda graph in `super().__init__()`
@@ -267,6 +266,43 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "trtllm_mha":
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+                TRTLLMHAAttnMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
+        elif self.server_args.attention_backend == "trtllm_mla":
+            if not global_server_args_dict["use_mla_backend"]:
+                raise ValueError(
+                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+                )
+
+            from sglang.srt.layers.attention.trtllm_mla_backend import (
+                TRTLLMMLABackend,
+                TRTLLMMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMMLABackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -836,6 +872,21 @@ class EAGLEWorker(TpModelWorker):
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
+        has_finished, unfinished_req_index = False, []
+        for i, req in enumerate(batch.reqs):
+            if req.finished():
+                has_finished = True
+            else:
+                unfinished_req_index.append(i)
+        if has_finished:
+            unfinished_index_device = torch.tensor(
+                unfinished_req_index,
+                dtype=torch.int64,
+                device=batch.spec_info.topk_p.device,
+            )
+            batch.spec_info.filter_batch(
+                unfinished_index_device, has_been_filtered=False
+            )
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         assert isinstance(batch.spec_info, EagleDraftInput)
@@ -966,7 +1017,9 @@ def get_last_loc_large_page_size_top_k_1(
     return prefix_lens, seq_lens, last_loc
 
 
-@torch.compile(dynamic=True)
+# Disable torch.compile for this function because it will be
+# even slower.
+# @torch.compile(dynamic=True)
 def get_last_loc_large_page_size_large_top_k(
     req_to_token: torch.Tensor,
     req_pool_indices: torch.Tensor,
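The hunk above wires the trtllm_mha and trtllm_mla attention backends into the EAGLE draft/verify path. A rough sketch of enabling it through the offline engine API; the keyword names mirror sglang's server arguments and the model paths are placeholders, so treat this as an assumption rather than a tested invocation:

    import sglang as sgl

    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",       # placeholder target model
        speculative_algorithm="EAGLE",
        speculative_draft_model_path="path/to/eagle-draft",   # placeholder draft model
        speculative_num_steps=3,
        speculative_eagle_topk=4,
        speculative_num_draft_tokens=8,
        attention_backend="trtllm_mha",  # selects TRTLLMHAAttnMultiStepDraftBackend above
    )
    print(llm.generate("Hello", {"max_new_tokens": 16}))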
sglang/srt/tokenizer/tiktoken_tokenizer.py
ADDED
@@ -0,0 +1,161 @@
+import functools
+import json
+from typing import AbstractSet, Collection, List, Literal, Union
+
+
+class TiktokenProcessor:
+    def __init__(self, name: str):
+        self.tokenizer = TiktokenTokenizer(name)
+
+    def image_processor(self, image):
+        return {"pixel_values": [image]}
+
+
+RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
+CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]
+
+
+PAD = "<|pad|>"
+EOS = "<|eos|>"
+SEP = "<|separator|>"
+
+DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
+DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
+# default + separate each single digit
+PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import tiktoken
+        from jinja2 import Template
+
+        # Read the JSON
+        with open(tokenizer_path, "rb") as fin:
+            xtok_dict = json.load(fin)
+
+        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in xtok_dict["special_tokens"]
+        }
+        if xtok_dict["word_split"] == "V1":
+            pad_str = PAT_STR_B
+        else:
+            assert False, f"Unknown word_split: {xtok_dict['word_split']}"
+        pad_str = xtok_dict.get("pat_str", pad_str)
+
+        kwargs = {
+            "name": tokenizer_path,
+            "pat_str": pad_str,
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in xtok_dict:
+            default_allowed_special = set(
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in xtok_dict["default_allowed_special"]
+                ]
+            )
+        if "vocab_size" in xtok_dict:
+            kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]
+
+        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__
+        default_allowed_special = None
+        control_tokens = DEFAULT_CONTROL_TOKENS
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._control_tokens = control_tokens
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[
+                Literal["all"], AbstractSet[str]
+            ] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> List[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=(),
+            )
+
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Allow more tokens to prevent crash
+        tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
+        tokenizer._default_allowed_special |= set(
+            CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
+        )
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.bos_token_id = None
+        self.eos_token_id = tokenizer._special_tokens[EOS]
+        self.vocab_size = tokenizer.n_vocab
+        self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        self.chat_template_jinja = Template(self.chat_template)
+        self.additional_stop_token_ids = None
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x, *args, **kwargs):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if len(batch) > 0 and isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode_batch(batch)
+
+    def apply_chat_template(
+        self, messages, tokenize, add_generation_prompt, tools=None
+    ):
+        ret = self.chat_template_jinja.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
+
+    def __call__(self, text, **kwargs):
+        return {
+            "input_ids": self.encode(text),
+        }
+
+    def init_xgrammar(self):
+        from xgrammar import TokenizerInfo
+
+        XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>"
+
+        enc = self.tokenizer
+        encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens}
+        encoded_vocab = [
+            token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
+        ]
+        override_stop_tokens = [2]  # eos
+        # These are treated as special tokens in xgrammar; we want to avoid them
+        # For now, xgrammar treats anything starting with b'\x00' as a special token
+        xgrammar_special_token_ids = []
+        for i, token in enumerate(encoded_vocab):
+            if isinstance(token, bytes) and token.startswith(b"\x00"):
+                xgrammar_special_token_ids.append(i)
+
+        for i, id in enumerate(xgrammar_special_token_ids):
+            encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i)
+        tokenizer_info = TokenizerInfo(
+            encoded_vocab, stop_token_ids=override_stop_tokens
+        )
+        assert len(tokenizer_info.special_token_ids) == 0
+
+        return tokenizer_info, override_stop_tokens
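A short usage sketch of the new wrapper, assuming a tiktoken dump at ./tokenizer.json with the regular_tokens / special_tokens layout the loader expects (illustrative only, not part of the diff):

    from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

    tok = TiktokenTokenizer("./tokenizer.json")  # hypothetical path
    ids = tok.encode("Hello world")
    print(tok.decode(ids))

    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "Hi"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # -> "Human: Hi<|separator|>\n\nAssistant:" per the built-in chat template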
sglang/srt/two_batch_overlap.py
CHANGED
@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import (
     CommunicateSummableTensorPairFn,
     ScatterMode,
 )
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_tbo_token_distribution_threshold,
+    is_tbo_enabled,
+)
 from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
     vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
     left_sum = sum(extend_lens[:vanilla_split_seq_index])
     overall_sum = sum(extend_lens)
-    threshold =
+    threshold = get_tbo_token_distribution_threshold()
     assert threshold <= 0.5, f"{threshold=}"
     return left_sum < overall_sum * threshold or left_sum > overall_sum * (
         1 - threshold
@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin:
         self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
 
     def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
-        if not
+        if not is_tbo_enabled():
             return
         token_num_per_seq = get_token_num_per_seq(
             forward_mode=batch.forward_mode, spec_info=batch.spec_info
@@ -353,10 +358,12 @@ class TboDPAttentionPreparer:
     def prepare_all_gather(
         self,
         local_batch: ScheduleBatch,
-        deepep_mode: DeepEPMode,
-        enable_deepep_moe: bool,
-        enable_two_batch_overlap: bool,
     ):
+
+        deepep_mode = get_deepep_mode()
+        enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+        enable_two_batch_overlap = is_tbo_enabled()
+
         self.enable_two_batch_overlap = enable_two_batch_overlap
 
         if local_batch is not None:
@@ -384,7 +391,7 @@ class TboDPAttentionPreparer:
                     and not local_batch.forward_mode.is_target_verify()
                 )
                 and enable_deepep_moe
-                and (resolved_deepep_mode
+                and (resolved_deepep_mode.is_low_latency())
             )
         else:
             self.local_tbo_split_seq_index = 0
@@ -657,6 +664,7 @@ class TboForwardBatchPreparer:
            "req_to_token_pool",
            "token_to_kv_pool",
            "can_run_dp_cuda_graph",
+           "dp_padding_mode",
            "global_forward_mode",
            "spec_algorithm",
            "capture_hidden_mode",
@@ -701,7 +709,6 @@ class TboForwardBatchPreparer:
            tbo_children=None,
            global_num_tokens_gpu=None,
            global_num_tokens_cpu=None,
-           dp_padding_mode=None,
            global_dp_buffer_len=global_dp_buffer_len,
            global_num_tokens_for_logprob_gpu=None,
            global_num_tokens_for_logprob_cpu=None,
@@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
 
 class MaybeTboDeepEPDispatcher:
     def __init__(self, **kwargs):
-        num_inner_dispatchers = (
-            2 if global_server_args_dict["enable_two_batch_overlap"] else 1
-        )
+        num_inner_dispatchers = 2 if is_tbo_enabled() else 1
         self._inners = [
             DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
         ]
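The changes above replace constructor arguments and global_server_args_dict lookups with module-level accessors from sglang.srt.layers.moe. A minimal sketch of how a call site reads the same settings after this refactor (illustrative only; assumes the MoE/server state has already been initialized):

    from sglang.srt.layers.moe import (
        get_deepep_mode,
        get_moe_a2a_backend,
        get_tbo_token_distribution_threshold,
        is_tbo_enabled,
    )

    if is_tbo_enabled() and get_moe_a2a_backend().is_deepep():
        deepep_mode = get_deepep_mode()
        threshold = get_tbo_token_distribution_threshold()  # callers assert <= 0.5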
sglang/srt/utils.py
CHANGED
@@ -438,70 +438,6 @@ def is_pin_memory_available() -> bool:
     return torch.cuda.is_available()
 
 
-_CPU_OFFLOAD_BYTES = 0
-_CPU_OFFLOAD_MAX_BYTES = 0
-
-
-def set_cpu_offload_max_bytes(max_bytes: int) -> None:
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    _CPU_OFFLOAD_BYTES = 0
-    _CPU_OFFLOAD_MAX_BYTES = max_bytes
-
-
-def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
-    device = next(module.parameters()).device
-
-    if device == torch.device("cpu"):
-        return module
-
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-        return module
-
-    pin_memory = is_pin_memory_available()
-    # offload parameters to CPU
-    # use pin_memory if possible, which helps cudagraph capture speed
-    offloaded_parameters = False
-    for p in module.parameters():
-        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-            # we use per-parameter offloading
-            # one module might have some parameters offloaded and some not
-            break
-
-        # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty_strided(
-            size=p.data.size(),
-            stride=p.data.stride(),
-            dtype=p.data.dtype,
-            layout=p.data.layout,
-            device="cpu",
-            pin_memory=pin_memory,
-        )
-        cpu_data.copy_(p.data)
-        p.data = cpu_data
-        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
-        offloaded_parameters = True
-
-    if offloaded_parameters:
-        original_forward = module.forward
-
-        def forward(*args, **kwargs):
-            module.forward = original_forward
-            device_state = {
-                # here we blindly call `to(device)`
-                # if the parameter is already on the device, it will be a no-op
-                k: v.to(device, non_blocking=True)
-                for k, v in module.state_dict().items()
-            }
-            output = functional_call(module, device_state, args=args, kwargs=kwargs)
-            module.forward = forward
-            return output
-
-        module.forward = forward
-
-    return module
-
-
 class LayerFn(Protocol):
 
     def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
@@ -514,11 +450,13 @@ def make_layers(
     pp_size: Optional[int] = None,
     prefix: str = "",
     return_tuple: bool = False,
+    offloader_kwargs: Dict[str, Any] = {},
 ) -> Tuple[int, int, torch.nn.ModuleList]:
     """Make a list of layers with the given layer function"""
     # circula imports
     from sglang.srt.distributed import get_pp_indices
     from sglang.srt.layers.utils import PPMissingLayer
+    from sglang.srt.offloader import get_offloader
 
     assert not pp_size or num_hidden_layers >= pp_size
     start_layer, end_layer = (
@@ -532,10 +470,13 @@ def make_layers(
     )
     modules = torch.nn.ModuleList(
         [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
-        + [
-            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
-            for idx in range(start_layer, end_layer)
-        ]
+        + get_offloader().wrap_modules(
+            (
+                layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
+                for idx in range(start_layer, end_layer)
+            ),
+            **offloader_kwargs,
+        )
         + [
             PPMissingLayer(return_tuple=return_tuple)
             for _ in range(end_layer, num_hidden_layers)
@@ -2343,6 +2284,7 @@ def is_fa3_default_architecture(hf_config):
         "Qwen3ForCausalLM",
         "Qwen3MoeForCausalLM",
         "Glm4MoeForCausalLM",
+        "Glm4vMoeForConditionalGeneration",
         "Step3VLForConditionalGeneration",
     }
     return architectures[0] in default_archs
@@ -2413,7 +2355,7 @@ def require_mlp_tp_gather(server_args):
         return True
     elif not server_args.enable_dp_lm_head:
         return True
-    elif server_args.moe_a2a_backend
+    elif server_args.moe_a2a_backend == "none":
         return True
     else:
         return (
@@ -2429,7 +2371,7 @@ def require_attn_tp_gather(server_args):
     Check if the input of attention is scattered.
     """
     assert server_args.moe_dense_tp_size in [1, None]
-    if server_args.moe_a2a_backend
+    if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
         if server_args.enable_dp_attention:
             return server_args.dp_size < server_args.tp_size
         else:
@@ -2599,6 +2541,50 @@ def dynamic_import(func_path: str):
     return func
 
 
+def gc_object_counts():
+    import gc
+
+    g0 = len(gc.get_objects(0))
+    g1 = len(gc.get_objects(1))
+    g2 = len(gc.get_objects(2))
+    return g0, g1, g2
+
+
+def configure_gc_warning(warn_threshold_secs):
+    import gc
+
+    gc_start_time = {}
+
+    def gc_callback(phase, info):
+        gen = info.get("generation", "?")
+        if phase == "start":
+            gc_start_time[gen] = time.time()
+        elif phase == "stop":
+            duration = time.time() - gc_start_time.get(gen, time.time())
+            if duration > warn_threshold_secs:
+                g0, g1, g2 = gc_object_counts()
+                logger.warn(
+                    f"LONG GARBAGE COLLECTION DETECTED | Generation {gen} | Duration: {duration:.4f}s | # Objects: gen0={g0}, gen1={g1}, gen2={g2} | "
+                    f"This may cause latency jitter. Consider calling the freeze_gc API after sending a few warmup requests."
+                )
+
+    gc.callbacks.append(gc_callback)
+
+
+def freeze_gc(context: str):
+    import gc
+
+    g0_before, g1_before, g2_before = gc_object_counts()
+    gc.freeze()
+    g0_after, g1_after, g2_after = gc_object_counts()
+    logger.info(
+        f"Freezing GC in {context} process. "
+        f"gen0: {g0_before}->{g0_after}, "
+        f"gen1: {g1_before}->{g1_after}, "
+        f"gen2: {g2_before}->{g2_after}"
+    )
+
+
 def configure_gc_logger():
     logger.info("Enable GC Logger")
 
@@ -2872,6 +2858,8 @@ SUPPORTED_LORA_TARGET_MODULES = [
     "gate_proj",
     "up_proj",
     "down_proj",
+    "qkv_proj",
+    "gate_up_proj",
 ]
 
 LORA_TARGET_ALL_MODULES = "all"
@@ -2966,3 +2954,13 @@ class ConcurrentCounter:
 @lru_cache(maxsize=1)
 def is_triton_kernels_available() -> bool:
     return importlib.util.find_spec("triton_kernels") is not None
+
+
+def check_cuda_result(raw_output):
+    import cuda.bindings.runtime as cuda_rt
+
+    err, *results = raw_output
+    if err != cuda_rt.cudaError_t.cudaSuccess:
+        raise Exception(f"CUDA error: {err}")
+
+    return results
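The new check_cuda_result helper unpacks the (error, *values) tuples returned by cuda-python's runtime bindings and raises on any non-success status. A small sketch, assuming the cuda-python package is installed (illustrative only, not part of the diff):

    import cuda.bindings.runtime as cuda_rt

    from sglang.srt.utils import check_cuda_result

    free_bytes, total_bytes = check_cuda_result(cuda_rt.cudaMemGetInfo())
    (device_count,) = check_cuda_result(cuda_rt.cudaGetDeviceCount())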
sglang/test/runners.py
CHANGED
@@ -231,11 +231,14 @@ class HFRunner:
 
         # Load the model and tokenizer
         if self.model_type == "generation":
-            config = AutoConfig.from_pretrained(
-
-
-
+            config = AutoConfig.from_pretrained(
+                model_path, trust_remote_code=self.trust_remote_code
+            )
+            if self.trust_remote_code:
                 model_cls = AutoModelForCausalLM
+            else:
+                model_arch = getattr(config, "architectures")[0]
+                model_cls = getattr(transformers, model_arch)
             self.base_model = model_cls.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
@@ -488,7 +491,7 @@ class SRTRunner:
         tp_size: int = 1,
         model_impl: str = "auto",
         port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-        lora_paths: List[str] = None,
+        lora_paths: Optional[Union[List[str], List[dict[str, str]]]] = None,
         max_loras_per_batch: int = 4,
         attention_backend: Optional[str] = None,
         prefill_attention_backend: Optional[str] = None,
sglang/test/test_block_fp8.py
CHANGED
@@ -6,7 +6,7 @@ import torch
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
@@ -498,11 +498,13 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
         score = torch.randn((M, E), dtype=dtype)
 
         with torch.inference_mode():
+            ref_out = torch_w8a8_block_fp8_moe(
+                a, w1, w2, w1_s, w2_s, score, topk, block_size
+            )
             topk_output = select_experts(
                 hidden_states=a,
                 router_logits=score,
-                top_k=topk,
-                renormalize=False,
+                topk_config=TopKConfig(top_k=topk, renormalize=False),
             )
             out = fused_moe(
                 a,
@@ -514,9 +516,6 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
                 w2_scale=w2_s,
                 block_shape=block_size,
             )
-            ref_out = torch_w8a8_block_fp8_moe(
-                a, w1, w2, w1_s, w2_s, score, topk, block_size
-            )
 
         self.assertTrue(
             torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))