sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
@@ -172,6 +172,20 @@ def is_blackwell():
     return torch.cuda.get_device_capability()[0] == 10
 
 
+@lru_cache(maxsize=1)
+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
+
+
 _warned_bool_env_var_keys = set()
 
 
@@ -216,8 +230,16 @@ except:
     is_intel_amx_backend_available = False
 
 
+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+
+
 def cpu_has_amx_support():
-    return
+    return is_amx_tile_supported and is_intel_amx_backend_available
 
 
 def use_intel_amx_backend(layer):
@@ -412,7 +434,9 @@ def get_available_gpu_memory(
 
     elif device == "cpu":
         # TODO: rename the variables in the current function to be not GPU specific
-
+        total_free_memory = psutil.virtual_memory().available
+        n_numa_node: int = len(get_cpu_ids_by_node())
+        free_gpu_memory = round(total_free_memory / n_numa_node, 3)
     elif device == "npu":
         num_gpus = torch.npu.device_count()
         assert gpu_id < num_gpus
@@ -1665,9 +1689,29 @@ def direct_register_custom_op(
     IMPORTANT: the lifetime of the operator is tied to the lifetime of the
     library object. If you want to bind the operator to a different library,
     make sure the library object is alive when the operator is used.
+
+    Note: This function will silently skip registration if the operator
+    with the same name is already registered to avoid RuntimeError in
+    multi-engine scenarios (e.g., VERL framework).
     """
     import torch.library
 
+    my_lib = target_lib or sglang_lib
+
+    # Check if operator is already registered to avoid duplicate registration
+    # This is important for scenarios where multiple SGLang engines run in the same process
+    try:
+        # Try to access the operator to see if it's already registered
+        lib_name = my_lib.m.name if hasattr(my_lib.m, "name") else "sglang"
+        if hasattr(torch.ops, lib_name) and hasattr(
+            getattr(torch.ops, lib_name), op_name
+        ):
+            # Operator already exists, skip registration
+            return
+    except (AttributeError, RuntimeError):
+        # Operator doesn't exist, proceed with registration
+        pass
+
     if hasattr(torch.library, "infer_schema"):
         schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
     else:
@@ -1676,11 +1720,22 @@ def direct_register_custom_op(
 
     schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
 
-
-
-
-
-
+    try:
+        my_lib.define(op_name + schema_str)
+        my_lib.impl(op_name, op_func, "CUDA")
+        if fake_impl is not None:
+            my_lib._register_fake(op_name, fake_impl)
+    except RuntimeError as error:
+        if "Tried to register an operator" in str(error) and "multiple times" in str(error):
+            # Silently ignore duplicate registration errors
+            # This can happen in multi-engine scenarios
+            pass
+        else:
+            # Re-raise other RuntimeErrors
+            raise error
+    except AttributeError as error:
+        # Always re-raise AttributeError as it indicates missing dependencies
+        raise error
 
 
 def set_gpu_proc_affinity(
@@ -1919,6 +1974,15 @@ def get_ip() -> str:
     except Exception:
         pass
 
+    # try using hostname
+    hostname = socket.gethostname()
+    try:
+        ip_addr = socket.gethostbyname(hostname)
+        warnings.warn("using local ip address: {}".format(ip_addr))
+        return ip_addr
+    except Exception:
+        pass
+
     warnings.warn(
         "Failed to get the IP address, using 0.0.0.0 by default."
         "The value can be set by the environment variable"
@@ -2733,6 +2797,10 @@ def lru_cache_frozenset(maxsize=128):
     return decorator
 
 
+def get_origin_rid(rid):
+    return rid.split("_", 1)[1] if "_" in rid else rid
+
+
 def apply_module_patch(target_module, target_function, wrappers):
     original_module, original_function = parse_module_path(
         target_module, target_function, False
@@ -2842,6 +2910,18 @@ def mxfp_supported():
         return False
 
 
+@lru_cache(maxsize=1)
+def is_gfx95_supported():
+    """
+    Returns whether the current platform supports MX types.
+    """
+    if torch.version.hip:
+        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
+        return any(gfx in gcn_arch for gfx in ["gfx95"])
+    else:
+        return False
+
+
 # LoRA-related constants and utilities
 SUPPORTED_LORA_TARGET_MODULES = [
     "q_proj",
@@ -2957,3 +3037,12 @@ def check_cuda_result(raw_output):
         raise Exception(f"CUDA error: {err}")
 
     return results
+
+
+def numa_bind_to_node(node: int):
+    libnuma = ctypes.CDLL("libnuma.so")
+    if libnuma.numa_available() < 0:
+        raise SystemError("numa not available on this system")
+
+    libnuma.numa_run_on_node(ctypes.c_int(node))
+    libnuma.numa_set_localalloc()
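To make the new utility additions above concrete, here is a small illustrative sketch (not part of the diff) of how they could be called. It assumes a CUDA-enabled PyTorch build and an available libnuma shared library; nothing here is guaranteed by the package itself.

# Illustrative sketch (not part of the diff): calling the helpers added to
# sglang.srt.utils in 0.5.2. Assumes CUDA-enabled PyTorch and libnuma.
from sglang.srt.utils import get_origin_rid, is_sm100_supported, numa_bind_to_node

# get_origin_rid() returns the part after the first underscore, e.g. to recover
# an original request id from a prefixed one.
print(get_origin_rid("3_req-abc"))  # -> "req-abc"

# Cached capability gate: compute capability 10 (Blackwell) plus CUDA >= 12.8.
if is_sm100_supported():
    print("SM100 kernels are usable on this machine")

# Bind the current process to NUMA node 0 and prefer local memory allocations.
numa_bind_to_node(0)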
sglang/srt/weight_sync/utils.py
CHANGED
@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
 
 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.managers.
+from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.utils import MultiprocessingSerializer
 
sglang/test/attention/test_trtllm_mla_backend.py
CHANGED
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
     "v_head_dim": 512,
     "num_kv_heads": 1,
     "layer_id": 0,
+    "tp_q_head_num": 128,
+    "tp_k_head_num": 128,
+    "prefill_head_dim": 192,
+    "prefill_v_head_dim": 128,
 }
 
 ROPE_BASE = 10000
@@ -92,7 +96,7 @@ TEST_CASES = {
             "description": "Medium-scale batch",
         },
     ],
-    "
+    "output_match": [
         {
             "name": "single_fp16",
             "batch_size": 1,
@@ -208,6 +212,15 @@ class MockModelRunner:
         self.kv_cache_dtype = config["kv_cache_dtype"]
         self.page_size = config["page_size"]
 
+        # Server args stub - needed by attention backends
+        self.server_args = type(
+            "ServerArgs",
+            (),
+            {
+                "enable_dp_attention": False,  # Default value for testing
+            },
+        )
+
         # Model-config stub with MLA attributes
         self.model_config = type(
             "ModelConfig",
@@ -313,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config.update(test_case)
         return config
 
-    def _create_model_components(self, config):
+    def _create_model_components(self, config, is_prefill=False):
         """Create model runners, backends, and layer for testing."""
         # Create model runners
         model_runner_trtllm = MockModelRunner(config)
@@ -323,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
         trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
         reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
 
+        head_dim = (
+            config["kv_lora_rank"] + config["qk_rope_head_dim"]
+            if not is_prefill
+            else config["prefill_head_dim"]
+        )
+        v_head_dim = (
+            config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
+        )
+
         # Create RadixAttention layer
         layer = RadixAttention(
             num_heads=config["num_attention_heads"],
-            head_dim=
+            head_dim=head_dim,
             scaling=model_runner_trtllm.model_config.scaling,
             num_kv_heads=config["num_kv_heads"],
             layer_id=config["layer_id"],
-            v_head_dim=
+            v_head_dim=v_head_dim,
             prefix="attn_mqa",
         )
 
@@ -515,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
         """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
         print(f"\nRunning decode output matching tests...")
 
-        for test_case in TEST_CASES["
+        for test_case in TEST_CASES["output_match"]:
            with self.subTest(test_case=test_case["name"]):
                print(f"  Testing {test_case['name']}: {test_case['description']}")
 
@@ -833,7 +855,7 @@ class TestTRTLLMMLA(CustomTestCase):
 
         # Test workspace properties
         self.assertEqual(metadata.workspace.device.type, "cuda")
-        self.assertEqual(metadata.workspace.dtype, torch.
+        self.assertEqual(metadata.workspace.dtype, torch.uint8)
         self.assertGreater(
             metadata.workspace.numel(), 0, "Workspace should have non-zero size"
         )
@@ -993,8 +1015,8 @@ class TestTRTLLMMLA(CustomTestCase):
         )
 
         # Verify CUDA graph buffers are allocated
-        self.assertIsNotNone(backend.
-        self.assertIsNotNone(backend.
+        self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
+        self.assertIsNotNone(backend.decode_cuda_graph_workspace)
 
         # Test capture metadata
         seq_lens = torch.full(
@@ -1090,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
         self.assertIsNotNone(metadata_3.block_kv_indices)
         self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
 
+    def test_prefill_output_match_self_attention(self):
+        """Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
+        print(f"\nRunning prefill output tests...")
+
+        for test_case in TEST_CASES["output_match"][:2]:  # Just a subset for speed
+            with self.subTest(test_case=test_case["name"]):
+                print(
+                    f"Prefill Testing {test_case['name']}: {test_case['description']}"
+                )
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                # Create components
+                (
+                    model_runner_trtllm,
+                    model_runner_reference,
+                    trtllm_backend,
+                    reference_backend,
+                    layer,
+                ) = self._create_model_components(config, is_prefill=True)
+
+                # Prefill uses full sequences
+                seq_lens = torch.full(
+                    (batch_size,), max_seq_len, device=config["device"]
+                )
+
+                def _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens,
+                    extend_prefix_lens,
+                    backend,
+                    model_runner,
+                    config,
+                ):
+                    """Create a forward batch for the given backend."""
+
+                    fb = ForwardBatch(
+                        batch_size=batch_size,
+                        input_ids=torch.randint(
+                            0, 100, (batch_size, 1), device=config["device"]
+                        ),
+                        out_cache_loc=torch.arange(batch_size, device=config["device"]),
+                        seq_lens_sum=int(seq_lens.sum().item()),
+                        extend_prefix_lens=extend_prefix_lens,
+                        extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
+                        extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
+                        .cpu()
+                        .int()
+                        .tolist(),
+                        forward_mode=ForwardMode.EXTEND,
+                        req_pool_indices=torch.arange(
+                            batch_size, device=config["device"]
+                        ),
+                        seq_lens=seq_lens,
+                        seq_lens_cpu=seq_lens.cpu(),
+                        attn_attend_prefix_cache=False,
+                        mha_return_lse=False,
+                        attn_backend=backend,
+                    )
+                    fb.req_to_token_pool = model_runner.req_to_token_pool
+                    fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+                    # Add position information for RoPE
+                    fb.positions = torch.arange(batch_size, device=config["device"])
+
+                    return fb
+
+                # Create forward batches
+                fb_trtllm = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    trtllm_backend,
+                    model_runner_trtllm,
+                    config,
+                )
+                fb_reference = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    reference_backend,
+                    model_runner_reference,
+                    config,
+                )
+
+                # Initialize metadata for both backends
+                trtllm_backend.init_forward_metadata(fb_trtllm)
+                reference_backend.init_forward_metadata(fb_reference)
+
+                # Create Q, K, V tensors for prefill
+                torch.manual_seed(config["seed_qkv"])
+
+                def _create_qkv_tensors_prefill(
+                    batch_size, seq_len, config, dtype_override=None
+                ):
+                    """Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
+                    device = config["device"]
+                    dtype = dtype_override or config["dtype"]
+
+                    total_tokens = batch_size * seq_len
+
+                    tp_q_head_num = config["tp_q_head_num"]
+                    tp_k_head_num = config["tp_k_head_num"]
+                    head_dim = config["prefill_head_dim"]
+                    v_head_dim = config["prefill_v_head_dim"]
+
+                    q = torch.randn(
+                        (total_tokens, tp_q_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    k = torch.randn(
+                        (total_tokens, tp_k_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    v = torch.randn(
+                        (total_tokens, tp_k_head_num * v_head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+
+                    # Reshape as requested
+                    q = q.view(-1, tp_q_head_num, head_dim)
+                    k = k.view(-1, tp_k_head_num, head_dim)
+                    v = v.view(-1, tp_k_head_num, v_head_dim)
+
+                    return q, k, v
+
+                q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
+                # Run prefill on both backends
+                out_trtllm = trtllm_backend.forward_extend(
+                    q, k, v, layer, fb_trtllm, False
+                ).view(-1, layer.tp_q_head_num * layer.v_head_dim)
+                out_reference = reference_backend.forward_extend(
+                    q, k, v, layer, fb_reference, False
+                )
+
+                tolerance = config.get("tolerance", 1e-2)
+                comparison_passed = compare_outputs(
+                    out_trtllm, out_reference, tolerance=tolerance
+                )
+                self.assertTrue(
+                    comparison_passed,
+                    f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
+                    f"Config: {test_case['name']}, "
+                    f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
+                )
+
 
 if __name__ == "__main__":
     unittest.main()
sglang/test/few_shot_gsm8k.py
CHANGED
sglang/test/runners.py
CHANGED
@@ -505,6 +505,7 @@ class SRTRunner:
         mem_fraction_static: float = 0.65,
         trust_remote_code: bool = False,
         speculative_draft_model_path: Optional[str] = None,
+        speculative_draft_model_revision: Optional[str] = None,
         speculative_algorithm: Optional[str] = None,
         speculative_num_steps: Optional[int] = None,
         speculative_eagle_topk: Optional[int] = None,
@@ -526,6 +527,9 @@ class SRTRunner:
         spec_kwargs = {}
         if speculative_draft_model_path:
             spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
+            spec_kwargs["speculative_draft_model_revision"] = (
+                speculative_draft_model_revision
+            )
             spec_kwargs["speculative_algorithm"] = speculative_algorithm
             spec_kwargs["speculative_num_steps"] = speculative_num_steps
             spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
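To show how the new argument above is meant to flow, here is an illustrative sketch (not part of the diff). The model names and revision are placeholders, and the remaining SRTRunner constructor arguments a real test would pass are omitted.

# Illustrative sketch (not part of the diff): speculative-decoding kwargs a test
# could pass so that SRTRunner forwards a pinned draft-model revision into
# spec_kwargs. Names and values below are hypothetical placeholders.
spec_test_kwargs = dict(
    speculative_algorithm="EAGLE",
    speculative_draft_model_path="example-org/draft-model",  # hypothetical
    speculative_draft_model_revision="main",                 # new in 0.5.2
    speculative_num_steps=3,
    speculative_eagle_topk=2,
)
# runner = SRTRunner(..., **spec_test_kwargs)  # other constructor args omitted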
sglang/test/test_cutlass_moe.py
CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoConfig
 from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 
 # Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
@@ -21,7 +22,7 @@ def calc_diff(x, y):
 
 def get_model_config(tp_size: int):
     config = AutoConfig.from_pretrained(
-        "deepseek-ai/
+        "deepseek-ai/Deepseek-R1", trust_remote_code=True
     )
     E = config.n_routed_experts
     topk = config.num_experts_per_tok
@@ -152,14 +153,31 @@ def run_test(tp_size, batch_size, model_config, check=False):
         problem_sizes2,
     )
 
+    topk_output = StandardTopKOutput(
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        router_logits=torch.randn(
+            (batch_size, topk), device=topk_weights.device, dtype=dtype
+        ),
+    )
+
+    moe_runner_config = MoeRunnerConfig(
+        num_experts=E,
+        top_k=topk,
+        hidden_size=H,
+        intermediate_size_per_partition=I,
+        params_dtype=dtype,
+        activation="silu",
+        inplace=False,
+    )
+
     # Note: Triton expects non-transposed weights
-    moe_config = MoeRunnerConfig(inplace=False)
     triton_lambda = lambda: fused_experts(
         x,
         w1,
         w2,
-
-
+        topk_output,
+        moe_runner_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
@@ -224,8 +242,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
         x,
         w1,  # Original shape
         w2,  # Original shape
-
-
+        topk_output,
+        moe_runner_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
sglang/test/test_cutlass_w4a8_moe.py
CHANGED
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Optional
+from typing import Literal, Optional
 
 import pytest
 import torch
@@ -25,7 +25,7 @@ def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Ten
     return packed_tensor.to(torch.int8)
 
 
-def pack_interleave(num_experts, ref_weight, ref_scale):
+def pack_interleave(num_experts, ref_weight, ref_scale, alignment=4):
     n, k = ref_weight.shape[1], ref_weight.shape[2]
 
     weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
@@ -33,11 +33,16 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
     w_q = w_q.contiguous()
 
     scale_interleaved = ref_scale.reshape(
-        ref_scale.shape[0],
+        ref_scale.shape[0],
+        ref_scale.shape[1],
+        (ref_scale.shape[2] // alignment),
+        alignment,
     )  # [E, N, K/4, 4]
     scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
     scale_interleaved = scale_interleaved.reshape(
-        ref_scale.shape[0],
+        ref_scale.shape[0],
+        ref_scale.shape[2] // alignment,
+        ref_scale.shape[1] * alignment,
     )  # [E, K/4, N*4]
     w_scale = scale_interleaved.contiguous()
 
@@ -48,12 +53,17 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
 @pytest.mark.parametrize("N", [2048])
 @pytest.mark.parametrize("K", [7168])
 @pytest.mark.parametrize("E", [256])
-@pytest.mark.parametrize("
+@pytest.mark.parametrize("tp_size", [8])
+@pytest.mark.parametrize("use_ep_moe", [True, False])
 @pytest.mark.parametrize("topk", [8])
 @pytest.mark.parametrize("group_size", [128])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_cutlass_w4a8_moe(M, N, K, E,
-
+def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dtype):
+    if use_ep_moe:
+        local_e = E // tp_size
+    else:  # tp mode
+        local_e = E
+        N = N // tp_size
 
     debug = False
     if debug:
@@ -87,7 +97,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
     )
 
     w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
-
+    if use_ep_moe:
+        w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
+    else:
+        w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2, 1)
 
     device = "cuda"
     a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
@@ -265,7 +278,9 @@ def ref(
 
     gate, fc1 = fc1.chunk(2, dim=-1)
     fc1 = fc1 * torch.nn.functional.silu(gate)
-    act = (fc1 / pre_quant_scale_2.float()).to(
+    act = torch.clamp((fc1 / pre_quant_scale_2.float()), -448.0, 448.0).to(
+        torch.float8_e4m3fn
+    )
     act = act.to(dtype)
 
     w2 = ref_weight_2[e_idx]
sglang/test/test_disaggregation_utils.py
ADDED
@@ -0,0 +1,66 @@
+import time
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    CustomTestCase,
+    popen_with_error_check,
+)
+
+
+class TestDisaggregationBase(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
+        pass
+
+    @classmethod
+    def launch_lb(cls):
+        lb_command = [
+            "python3",
+            "-m",
+            "sglang_router.launch_router",
+            "--pd-disaggregation",
+            "--mini-lb",  # FIXME: remove this
+            "--prefill",
+            cls.prefill_url,
+            "--decode",
+            cls.decode_url,
+            "--host",
+            cls.base_host,
+            "--port",
+            cls.lb_port,
+        ]
+        print("Starting load balancer:", " ".join(lb_command))
+        cls.process_lb = popen_with_error_check(lb_command)
+        cls.wait_server_ready(cls.lb_url + "/health")
+
+    @classmethod
+    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+
+        # wait for 5 seconds
+        time.sleep(5)
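For context, here is an illustrative sketch (not part of the diff) of how a concrete disaggregation test might build on the new base class. The hosts, ports, and URLs are made-up placeholders, and launching the actual prefill and decode servers is left out.

# Illustrative sketch (not part of the diff): a hypothetical subclass of the new
# TestDisaggregationBase. Attribute names follow the fields referenced in
# launch_lb() above; all values are placeholders.
from sglang.test.test_disaggregation_utils import TestDisaggregationBase


class TestMyDisaggregation(TestDisaggregationBase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.base_host = "127.0.0.1"
        cls.lb_port = "8000"
        cls.prefill_url = "http://127.0.0.1:30000"   # placeholder prefill server
        cls.decode_url = "http://127.0.0.1:30001"    # placeholder decode server
        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
        # ... launch the prefill and decode servers here (cls.process_prefill,
        # cls.process_decode), then start the router in mini-lb mode:
        cls.launch_lb()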