sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker):
 
     def forward_batch_speculative_generation(
         self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput,
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
        """Run speculative decoding forward.
 
        NOTE: Many states of batch is modified as you go through. It is not guaranteed that
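Callers now unpack five values from a fully annotated return type. A caller-side sketch; the names of the last three elements are assumptions for illustration (the annotation only pins down their types):

    # Hypothetical unpacking of the new return value of
    # EAGLEWorker.forward_batch_speculative_generation (names are illustrative).
    (
        logits_output,        # LogitsProcessorOutput
        verified_id,          # torch.Tensor of accepted token ids
        batch_id,             # int (assumed meaning)
        num_accepted_tokens,  # int (assumed meaning)
        can_run_cuda_graph,   # bool (assumed meaning)
    ) = worker.forward_batch_speculative_generation(batch)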
@@ -325,11 +325,16 @@ class EAGLEWorker(TpModelWorker):
                 self.verify(batch, spec_info)
             )
 
-
-
-
-
-
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                # NOTE: We should use `check_forward_draft_extend_after_decode`
+                # when DP attention is enabled, but it is slow. Skip it for now.
+                if (
+                    self.server_args.enable_dp_attention
+                    or batch.spec_info.verified_id.shape[0] > 0
+                ):
+                    # decode is not finished
+                    self.forward_draft_extend_after_decode(batch)
+
             return (
                 logits_output,
                 verify_output.verified_id,
@@ -339,10 +344,7 @@ class EAGLEWorker(TpModelWorker):
         )
 
     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        local_need_forward = (
-            batch.spec_info.verified_id is not None
-            and batch.spec_info.verified_id.shape[0] > 0
-        )
+        local_need_forward = batch.spec_info.verified_id.shape[0] > 0
         if not self.server_args.enable_dp_attention:
             return local_need_forward
 
@@ -361,7 +363,7 @@ class EAGLEWorker(TpModelWorker):
 
     def forward_target_extend(
         self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput,
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
        """Run the target extend.
 
        Args:
@@ -376,7 +378,6 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -508,13 +509,15 @@ class EAGLEWorker(TpModelWorker):
         self._draft_preprocess_decode(batch)
 
         spec_info = batch.spec_info
+        assert isinstance(spec_info, EagleDraftInput)
 
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        spec_info.num_tokens_per_batch = self.topk
+        spec_info.num_tokens_for_logprob_per_batch = self.topk
         batch.return_hidden_states = False
 
         # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.topk
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -527,6 +530,7 @@ class EAGLEWorker(TpModelWorker):
                 forward_batch
             )
         else:
+            forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
                 # Initialize attention backend
                 self.draft_attn_backend.init_forward_metadata(forward_batch)
@@ -578,6 +582,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
         spec_info = forward_batch.spec_info
+        assert isinstance(spec_info, EagleDraftInput)
         out_cache_loc = forward_batch.out_cache_loc
         topk_p, topk_index, hidden_states = (
             spec_info.topk_p,
@@ -621,8 +626,8 @@ class EAGLEWorker(TpModelWorker):
         spec_info.hidden_states = hidden_states
 
         # Run forward
-        logits_output = self.draft_model_runner.forward(
-            forward_batch
+        logits_output, _ = self.draft_model_runner.forward(
+            forward_batch, skip_attn_backend_init=True
         )
         self._detect_nan_if_needed(logits_output)
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
@@ -642,10 +647,10 @@ class EAGLEWorker(TpModelWorker):
             else ForwardMode.IDLE
         )
         batch.spec_info = spec_info
+
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
         assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
@@ -782,8 +787,8 @@ class EAGLEWorker(TpModelWorker):
         self,
         batch: ScheduleBatch,
         hidden_states: torch.Tensor,
-        next_token_ids:
-        seq_lens_cpu: torch.Tensor,
+        next_token_ids: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
        """Run draft model extend. This API modifies the states of the batch.
 
@@ -795,6 +800,8 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             verified_id=next_token_ids,
+            num_tokens_per_batch=1,
+            num_tokens_for_logprob_per_batch=1,
         )
         batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
@@ -802,7 +809,6 @@ class EAGLEWorker(TpModelWorker):
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
-        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -814,37 +820,45 @@ class EAGLEWorker(TpModelWorker):
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        assert isinstance(batch.spec_info, EagleDraftInput)
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
+
         input_is_idle = batch.forward_mode.is_idle()
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=hidden_size,
+                dtype=self.model_config.dtype,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+
+        batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
+        batch.spec_info.num_tokens_for_logprob_per_batch = 1
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
+        batch.forward_mode = (
+            ForwardMode.DRAFT_EXTEND
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
+
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -869,12 +883,13 @@ class EAGLEWorker(TpModelWorker):
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
+            forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
                 self.draft_model_runner.attn_backend.init_forward_metadata(
                     forward_batch
                 )
-            logits_output = self.draft_model_runner.forward(
-                forward_batch
+            logits_output, _ = self.draft_model_runner.forward(
+                forward_batch, skip_attn_backend_init=True
             )
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
sglang/srt/two_batch_overlap.py
CHANGED
@@ -341,15 +341,18 @@ class TboDPAttentionPreparer:
 
     @staticmethod
     def _compute_global_forward_mode(forward_modes):
-
-
-            for x in forward_modes
+        forward_modes_excluding_idle = [
+            x for x in forward_modes if x != ForwardMode.IDLE.value
         ]
+
+        if not forward_modes_excluding_idle:
+            return ForwardMode.IDLE, False
+
         forward_mode_agree = TboDPAttentionPreparer._is_all_same(
-
+            forward_modes_excluding_idle
         )
         global_forward_mode = (
-            ForwardMode(
+            ForwardMode(forward_modes_excluding_idle[0]) if forward_mode_agree else None
         )
         return global_forward_mode, forward_mode_agree
 
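With the early return, a step where every DP rank reports `ForwardMode.IDLE` now yields `(ForwardMode.IDLE, False)` instead of reducing over an empty list. A self-contained sketch of the new control flow, using simplified stand-ins for `ForwardMode` and `_is_all_same`:

    from enum import Enum

    class ForwardMode(Enum):  # simplified stand-in for sglang's ForwardMode
        IDLE = 1
        DECODE = 2

    def _is_all_same(xs):
        return all(x == xs[0] for x in xs)

    def compute_global_forward_mode(forward_modes):
        non_idle = [x for x in forward_modes if x != ForwardMode.IDLE.value]
        if not non_idle:  # the new early return: every rank is idle
            return ForwardMode.IDLE, False
        agree = _is_all_same(non_idle)
        return (ForwardMode(non_idle[0]) if agree else None), agree

    print(compute_global_forward_mode([ForwardMode.IDLE.value] * 4))
    # (<ForwardMode.IDLE: 1>, False)
    print(compute_global_forward_mode([ForwardMode.DECODE.value, ForwardMode.IDLE.value]))
    # (<ForwardMode.DECODE: 2>, True)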
@@ -500,6 +503,7 @@ class TboForwardBatchPreparer:
             "capture_hidden_mode",
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
+            "split_index",  # for split prefill
         ]:
             output_dict[key] = getattr(batch, key)
         if not batch.forward_mode.is_target_verify():
@@ -541,6 +545,7 @@ class TboForwardBatchPreparer:
             tbo_children=None,
             global_num_tokens_gpu=None,
             global_num_tokens_cpu=None,
+            dp_padding_mode=None,
             gathered_buffer=gathered_buffer,
             global_num_tokens_for_logprob_gpu=None,
             global_num_tokens_for_logprob_cpu=None,
sglang/srt/utils.py
CHANGED
@@ -691,12 +691,17 @@ def decode_video_base64(video_base64):
     )  # Return an empty array and size tuple if no frames were found
 
 
-def load_audio(
+def load_audio(
+    audio_file: str, sr: Optional[int] = None, mono: bool = True
+) -> np.ndarray:
     # Use soundfile here, since librosa use it under the hood,
     # and librosa will not support audio loading in the future
     import soundfile as sf
     from scipy.signal import resample
 
+    if sr is None:
+        sr = 16000
+
     # Load audio data
     if isinstance(audio_file, bytes):
         audio, original_sr = sf.read(BytesIO(audio_file))
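The sample rate is now optional and resolves to 16000 Hz inside the function. A usage sketch of the new signature; the file name is a placeholder, and the resample-to-`sr` behavior is inferred from the `scipy.signal.resample` import rather than shown in this hunk:

    from sglang.srt.utils import load_audio

    audio = load_audio("speech.wav")                # uses the 16 kHz default
    audio_44k = load_audio("speech.wav", sr=44100)  # explicit target rate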
@@ -739,9 +744,13 @@ def load_image(
         image = Image.open(BytesIO(image_file))
     elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-        response = requests.get(image_file, stream=True, timeout=timeout)
-
-
+        response = requests.get(image_file, stream=True, timeout=timeout)
+        try:
+            response.raise_for_status()
+            image = Image.open(response.raw)
+            image.load()  # Force loading to avoid issues after closing the stream
+        finally:
+            response.close()
     elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
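The explicit `image.load()` matters because PIL opens files lazily: `Image.open` parses only the header and defers pixel decoding until first access, which would fail after `response.close()` tears down the HTTP stream. A minimal illustration of that lazy behavior (local file path is a placeholder):

    from PIL import Image

    with open("photo.jpg", "rb") as f:
        img = Image.open(f)  # lazy: only the header is parsed here
        img.load()           # decode pixels while the source is still open
    # f is closed, but img stays usable because its pixels were already decoded
    print(img.size)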
@@ -928,71 +937,6 @@ def monkey_patch_vllm_gguf_config():
     setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
 
 
-def maybe_set_triton_cache_manager() -> None:
-    """Set environment variable to tell Triton to use a
-    custom cache manager"""
-    cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
-    if cache_manger is None:
-        manager = "sglang.srt.utils:CustomCacheManager"
-        logger.debug("Setting Triton cache manager to: %s", manager)
-        os.environ["TRITON_CACHE_MANAGER"] = manager
-
-
-class CustomCacheManager(FileCacheManager):
-    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
-    def __init__(self, key, override=False, dump=False):
-        from sglang.srt.distributed.parallel_state import get_tp_group
-
-        self.key = key
-        self.lock_path = None
-
-        try:
-            module_path = "triton.runtime.cache"
-            cache_module = importlib.import_module(module_path)
-
-            default_cache_dir = getattr(cache_module, "default_cache_dir", None)
-            default_dump_dir = getattr(cache_module, "default_dump_dir", None)
-            default_override_dir = getattr(cache_module, "default_override_dir", None)
-        except (ModuleNotFoundError, AttributeError) as e:
-            default_cache_dir = None
-            default_dump_dir = None
-            default_override_dir = None
-
-        if dump:
-            self.cache_dir = (
-                default_dump_dir()
-                if default_dump_dir is not None
-                else os.path.join(Path.home(), ".triton", "dump")
-            )
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-            self.lock_path = os.path.join(self.cache_dir, "lock")
-            os.makedirs(self.cache_dir, exist_ok=True)
-        elif override:
-            self.cache_dir = (
-                default_override_dir()
-                if default_override_dir is not None
-                else os.path.join(Path.home(), ".triton", "override")
-            )
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-        else:
-            # create cache directory if it doesn't exist
-            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
-                default_cache_dir()
-                if default_cache_dir is not None
-                else os.path.join(Path.home(), ".triton", "cache")
-            )
-            if self.cache_dir:
-                try:
-                    self.cache_dir = f"{self.cache_dir}_{get_tp_group().local_rank}"
-                except:
-                    self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
-                self.cache_dir = os.path.join(self.cache_dir, self.key)
-                self.lock_path = os.path.join(self.cache_dir, "lock")
-                os.makedirs(self.cache_dir, exist_ok=True)
-            else:
-                raise RuntimeError("Could not create or locate cache dir")
-
-
 def set_ulimit(target_soft_limit=65535):
     # number of open files
     resource_type = resource.RLIMIT_NOFILE
@@ -1417,6 +1361,13 @@ def get_nvgpu_memory_capacity():
         ]
 
         if not memory_values:
+            # Fallback to torch.cuda.mem_get_info() when failed to get memory capacity from nvidia-smi,
+            # typically in NVIDIA MIG mode.
+            if torch.cuda.is_available():
+                logger.warning(
+                    "Failed to get GPU memory capacity from nvidia-smi, falling back to torch.cuda.mem_get_info()."
+                )
+                return torch.cuda.mem_get_info()[1] // 1024 // 1024  # unit: MB
             raise ValueError("No GPU memory values found.")
 
         # Return the minimum memory value
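`torch.cuda.mem_get_info()` returns `(free, total)` in bytes for the current device, so index `[1]` is total capacity and the two integer divisions convert bytes to MB, matching what the nvidia-smi path reports. A standalone sketch:

    import torch

    if torch.cuda.is_available():
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        total_mb = total_bytes // 1024 // 1024  # same conversion as the fallback above
        print(f"free={free_bytes} bytes, total={total_mb} MB")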
@@ -2049,6 +2000,16 @@ def is_valid_ipv6_address(address: str) -> bool:
         return False
 
 
+def maybe_wrap_ipv6_address(address: str) -> str:
+    if is_valid_ipv6_address(address):
+        return f"[{address}]"
+    return address
+
+
+def format_tcp_address(ip: str, port: int) -> str:
+    return f"tcp://{maybe_wrap_ipv6_address(ip)}:{port}"
+
+
 def configure_ipv6(dist_init_addr):
     addr = dist_init_addr
     end = addr.find("]")
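These helpers keep endpoint strings valid for both address families: a raw IPv6 address must be bracketed before a `:port` suffix can be appended unambiguously. Expected behavior, straight from the definitions above:

    format_tcp_address("10.0.0.1", 30000)   # -> "tcp://10.0.0.1:30000"
    format_tcp_address("::1", 30000)        # -> "tcp://[::1]:30000"
    maybe_wrap_ipv6_address("example.com")  # -> "example.com" (unchanged)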
@@ -2880,3 +2841,17 @@ def parse_module_path(module_path, function_name, create_dummy):
             return final_module, getattr(final_module, function_name)
 
     return final_module, None
+
+
+# LoRA-related constants and utilities
+SUPPORTED_LORA_TARGET_MODULES = [
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+
+LORA_TARGET_ALL_MODULES = "all"
sglang/test/runners.py
CHANGED
@@ -134,10 +134,12 @@ class HFRunner:
         model_type: str = "generation",
         output_str_only: bool = False,
         trust_remote_code: bool = False,
+        patch_model_do_sample_false: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
         self.trust_remote_code = trust_remote_code
+        self.patch_model_do_sample_false = patch_model_do_sample_false
 
         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()
@@ -292,6 +294,7 @@ class HFRunner:
                         torch_dtype=torch_dtype,
                         output_str_only=self.output_str_only,
                         token_ids_logprob=token_ids_logprob,
+                        patch_model_do_sample_false=self.patch_model_do_sample_false,
                     )
                 )
             elif self.model_type == "embedding":
@@ -380,6 +383,7 @@ class HFRunner:
         lora_paths: Optional[List[str]] = None,
         output_str_only: bool = False,
         token_ids_logprob: Optional[int] = None,
+        patch_model_do_sample_false: Optional[bool] = False,
     ) -> ModelOutput:
         output_strs = []
         top_input_logprobs = []
@@ -407,7 +411,8 @@ class HFRunner:
                 )
             else:
                 model = base_model
-
+            if patch_model_do_sample_false:
+                model.generation_config.do_sample = False
             outputs = model.generate(
                 input_ids=input_ids,
                 generation_config=GenerationConfig(
@@ -481,7 +486,7 @@ class SRTRunner:
         torch_dtype: torch.dtype,
         model_type: str,
         tp_size: int = 1,
-
+        model_impl: str = "auto",
         port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
@@ -505,6 +510,9 @@ class SRTRunner:
         torchao_config: Optional[str] = None,
         cuda_graph_max_bs: int = 4,
         sleep_on_idle=False,
+        max_lora_rank: Optional[int] = None,
+        lora_target_modules: Optional[List[str]] = None,
+        enable_lora: Optional[bool] = None,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -523,7 +531,7 @@ class SRTRunner:
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-
+            model_impl=model_impl,
             torchao_config=torchao_config,
             mem_fraction_static=mem_fraction_static,
             trust_remote_code=trust_remote_code,
@@ -543,6 +551,9 @@ class SRTRunner:
             cuda_graph_max_bs=cuda_graph_max_bs,
             disable_custom_all_reduce=disable_custom_all_reduce,
             sleep_on_idle=sleep_on_idle,
+            max_lora_rank=max_lora_rank,
+            lora_target_modules=lora_target_modules,
+            enable_lora=enable_lora,
             **spec_kwargs,
         )
 
sglang/test/test_activation.py
CHANGED
@@ -3,9 +3,12 @@ import unittest
 
 import torch
 
-from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.activation import GeluAndMul, QuickGELU
+from sglang.srt.utils import is_hip
 from sglang.test.test_utils import CustomTestCase
 
+_is_hip = is_hip()
+
 
 class TestGeluAndMul(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16]
@@ -52,5 +55,51 @@ class TestGeluAndMul(CustomTestCase):
                 self._run_gelu_and_mul_test(*params)
 
 
+class TestQuickGELU(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    NUM_TOKENS = [7, 83, 2048]  # batch = sequence length
+    DIMS = [512, 4096, 5120, 13824]  # all multiples of 16 bytes
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_gelu_quick_test(self, n_tok: int, dim: int, dtype: torch.dtype, seed: int):
+        torch.manual_seed(seed)
+
+        layer = QuickGELU().to(dtype=dtype)
+
+        x = torch.randn(n_tok, dim, dtype=dtype, device="cuda")
+
+        with torch.inference_mode():
+            ref = layer.forward_native(x)  # x * sigmoid(1.702 * x), fp32 math
+            if _is_hip:
+                out = layer.forward_hip(x)  # 128-bit vectorised kernel from sgl-kernel
+            else:
+                out = layer.forward_cuda(x)
+
+        tol = 1e-2 if dtype is torch.bfloat16 else 1e-3
+        self.assertTrue(
+            torch.allclose(out, ref, atol=tol, rtol=tol),
+            msg=f"Mismatch @ B={n_tok}, D={dim}, dtype={dtype}",
+        )
+        print(f"Match @ B={n_tok}, D={dim}, dtype={dtype}")
+
+    def test_quick_gelu(self):
+        for params in itertools.product(
+            self.NUM_TOKENS, self.DIMS, self.DTYPES, self.SEEDS
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                dim=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._run_gelu_quick_test(*params)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
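For reference, the baseline the new test checks against is the sigmoid approximation of GELU that `forward_native` computes. A self-contained sketch of that math:

    import torch

    def quick_gelu_ref(x: torch.Tensor) -> torch.Tensor:
        # QuickGELU: x * sigmoid(1.702 * x), computed in fp32 for accuracy
        x32 = x.to(torch.float32)
        return (x32 * torch.sigmoid(1.702 * x32)).to(x.dtype)

    x = torch.randn(4, 8, dtype=torch.bfloat16)
    print(quick_gelu_ref(x).shape)  # torch.Size([4, 8])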
sglang/test/test_block_fp8.py
CHANGED
@@ -6,6 +6,7 @@ import torch
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
@@ -497,13 +498,17 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
         score = torch.randn((M, E), dtype=dtype)
 
         with torch.inference_mode():
+            topk_output = select_experts(
+                hidden_states=a,
+                router_logits=score,
+                top_k=topk,
+                renormalize=False,
+            )
             out = fused_moe(
                 a,
                 w1,
                 w2,
-
-                topk,
-                renormalize=False,
+                topk_output,
                 use_fp8_w8a8=True,
                 w1_scale=w1_s,
                 w2_scale=w2_s,
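The updated test tracks an API migration in this release: `fused_moe` no longer receives raw router scores plus `topk`/`renormalize` arguments; expert routing is computed up front with `select_experts` and the result is passed through. Schematically (the exact positional argument removed from the old call is not fully shown in this diff):

    # Old (0.4.9.post2, schematic):
    #   out = fused_moe(a, w1, w2, <routing args>, topk, renormalize=False, ...)
    # New (0.4.9.post4):
    topk_output = select_experts(
        hidden_states=a, router_logits=score, top_k=topk, renormalize=False
    )
    out = fused_moe(a, w1, w2, topk_output,
                    use_fp8_w8a8=True, w1_scale=w1_s, w2_scale=w2_s)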
sglang/test/test_block_fp8_ep.py
CHANGED
@@ -40,7 +40,7 @@ def ep_moe(
     block_shape: Optional[List[int]] = None,
 ):
     use_blockwise_fp8 = block_shape is not None
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=top_k,
sglang/test/test_custom_ops.py
CHANGED
@@ -3,8 +3,13 @@
 import pytest
 import torch
 
-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
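The branch matters because the two FP8 storage formats have different ranges: OCP `float8_e4m3fn` reaches ±448, while `float8_e4m3fnuz` (selected when `is_fp8_fnuz()` is true, as on some ROCm devices) has no negative zero and tops out at ±240, so the reference quantizers must clamp against the matching `finfo`. A quick check:

    import torch

    print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
    print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0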
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     def quantize_ref_per_tensor(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(
+        finfo = torch.finfo(fp8_dtype)
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(
+        qweight = qweight.to(fp8_dtype)
         return qweight
 
     def dequantize_per_tensor(tensor, inv_scale, dtype):
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     )
 
 
-if
+if _is_cuda or _is_hip:
 
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
     def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
         def quantize_ref_per_token(tensor, inv_scale):
             # The reference implementation that fully aligns to
             # the kernel being tested.
-            finfo = torch.finfo(
+            finfo = torch.finfo(fp8_dtype)
             scale = inv_scale.reciprocal()
             qweight = (tensor.to(torch.float32) * scale).clamp(
                 min=finfo.min, max=finfo.max
             )
-            qweight = qweight.to(
+            qweight = qweight.to(fp8_dtype)
             return qweight
 
         def dequantize_per_token(tensor, inv_scale, dtype):
sglang/test/test_cutlass_w4a8_moe.py
CHANGED
@@ -100,12 +100,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
     s_strides2 = c_strides2
 
     score = torch.randn((M, E), dtype=dtype, device=device)
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=a,
         router_logits=score,
         top_k=topk,
-        use_grouped_topk=False,
-        renormalize=False,
     )
     expert_map = torch.arange(E, dtype=torch.int32, device=device)
     expert_map[local_e:] = E
sglang/test/test_fp4_moe.py
CHANGED
@@ -159,12 +159,10 @@ def test_cutlass_fp4_moe_no_graph(
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
 
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=a,
         router_logits=score,
         top_k=topk,
-        use_grouped_topk=False,
-        renormalize=False,
     )
 
     a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|