sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
@@ -230,8 +230,16 @@ except:
     is_intel_amx_backend_available = False


+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+
+
 def cpu_has_amx_support():
-    return
+    return is_amx_tile_supported and is_intel_amx_backend_available


 def use_intel_amx_backend(layer):
@@ -426,7 +434,9 @@ def get_available_gpu_memory(

     elif device == "cpu":
         # TODO: rename the variables in the current function to be not GPU specific
-
+        total_free_memory = psutil.virtual_memory().available
+        n_numa_node: int = len(get_cpu_ids_by_node())
+        free_gpu_memory = round(total_free_memory / n_numa_node, 3)
     elif device == "npu":
         num_gpus = torch.npu.device_count()
         assert gpu_id < num_gpus
@@ -2787,6 +2797,10 @@ def lru_cache_frozenset(maxsize=128):
     return decorator


+def get_origin_rid(rid):
+    return rid.split("_", 1)[1] if "_" in rid else rid
+
+
 def apply_module_patch(target_module, target_function, wrappers):
     original_module, original_function = parse_module_path(
         target_module, target_function, False
@@ -2896,6 +2910,18 @@ def mxfp_supported():
         return False


+@lru_cache(maxsize=1)
+def is_gfx95_supported():
+    """
+    Returns whether the current platform supports MX types.
+    """
+    if torch.version.hip:
+        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
+        return any(gfx in gcn_arch for gfx in ["gfx95"])
+    else:
+        return False
+
+
 # LoRA-related constants and utilities
 SUPPORTED_LORA_TARGET_MODULES = [
     "q_proj",
@@ -3011,3 +3037,12 @@ def check_cuda_result(raw_output):
         raise Exception(f"CUDA error: {err}")

     return results
+
+
+def numa_bind_to_node(node: int):
+    libnuma = ctypes.CDLL("libnuma.so")
+    if libnuma.numa_available() < 0:
+        raise SystemError("numa not available on this system")
+
+    libnuma.numa_run_on_node(ctypes.c_int(node))
+    libnuma.numa_set_localalloc()
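For orientation, a small usage sketch of the helpers this diff adds to sglang/srt/utils.py follows; the request-id strings and the NUMA node number are made-up values, not taken from the package.

```python
# Usage sketch (illustrative values only) for the new utils added in 0.5.2.
from sglang.srt.utils import get_origin_rid, numa_bind_to_node

# get_origin_rid strips a "<prefix>_" prefix from a request id and returns
# un-prefixed ids unchanged.
assert get_origin_rid("3_req-abc") == "req-abc"
assert get_origin_rid("req-abc") == "req-abc"

# numa_bind_to_node pins the current process to one NUMA node via libnuma;
# guard the call, since it fails when libnuma is missing or NUMA is unavailable.
try:
    numa_bind_to_node(0)
except (OSError, SystemError):
    pass
```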
sglang/srt/weight_sync/utils.py
CHANGED
@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor

 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.managers.
+from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.utils import MultiprocessingSerializer

sglang/test/attention/test_trtllm_mla_backend.py
CHANGED
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
     "v_head_dim": 512,
     "num_kv_heads": 1,
     "layer_id": 0,
+    "tp_q_head_num": 128,
+    "tp_k_head_num": 128,
+    "prefill_head_dim": 192,
+    "prefill_v_head_dim": 128,
 }

 ROPE_BASE = 10000
@@ -92,7 +96,7 @@ TEST_CASES = {
             "description": "Medium-scale batch",
         },
     ],
-    "
+    "output_match": [
        {
            "name": "single_fp16",
            "batch_size": 1,
@@ -208,6 +212,15 @@ class MockModelRunner:
         self.kv_cache_dtype = config["kv_cache_dtype"]
         self.page_size = config["page_size"]

+        # Server args stub - needed by attention backends
+        self.server_args = type(
+            "ServerArgs",
+            (),
+            {
+                "enable_dp_attention": False,  # Default value for testing
+            },
+        )
+
         # Model-config stub with MLA attributes
         self.model_config = type(
             "ModelConfig",
@@ -313,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config.update(test_case)
         return config

-    def _create_model_components(self, config):
+    def _create_model_components(self, config, is_prefill=False):
         """Create model runners, backends, and layer for testing."""
         # Create model runners
         model_runner_trtllm = MockModelRunner(config)
@@ -323,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
         trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
         reference_backend = FlashInferMLAAttnBackend(model_runner_reference)

+        head_dim = (
+            config["kv_lora_rank"] + config["qk_rope_head_dim"]
+            if not is_prefill
+            else config["prefill_head_dim"]
+        )
+        v_head_dim = (
+            config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
+        )
+
         # Create RadixAttention layer
         layer = RadixAttention(
             num_heads=config["num_attention_heads"],
-            head_dim=
+            head_dim=head_dim,
             scaling=model_runner_trtllm.model_config.scaling,
             num_kv_heads=config["num_kv_heads"],
             layer_id=config["layer_id"],
-            v_head_dim=
+            v_head_dim=v_head_dim,
             prefix="attn_mqa",
         )

@@ -515,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
         """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
         print(f"\nRunning decode output matching tests...")

-        for test_case in TEST_CASES["
+        for test_case in TEST_CASES["output_match"]:
             with self.subTest(test_case=test_case["name"]):
                 print(f"  Testing {test_case['name']}: {test_case['description']}")

@@ -833,7 +855,7 @@ class TestTRTLLMMLA(CustomTestCase):

         # Test workspace properties
         self.assertEqual(metadata.workspace.device.type, "cuda")
-        self.assertEqual(metadata.workspace.dtype, torch.
+        self.assertEqual(metadata.workspace.dtype, torch.uint8)
         self.assertGreater(
             metadata.workspace.numel(), 0, "Workspace should have non-zero size"
         )
@@ -993,8 +1015,8 @@ class TestTRTLLMMLA(CustomTestCase):
         )

         # Verify CUDA graph buffers are allocated
-        self.assertIsNotNone(backend.
-        self.assertIsNotNone(backend.
+        self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
+        self.assertIsNotNone(backend.decode_cuda_graph_workspace)

         # Test capture metadata
         seq_lens = torch.full(
@@ -1090,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
         self.assertIsNotNone(metadata_3.block_kv_indices)
         self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])

+    def test_prefill_output_match_self_attention(self):
+        """Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
+        print(f"\nRunning prefill output tests...")
+
+        for test_case in TEST_CASES["output_match"][:2]:  # Just a subset for speed
+            with self.subTest(test_case=test_case["name"]):
+                print(
+                    f"Prefill Testing {test_case['name']}: {test_case['description']}"
+                )
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                # Create components
+                (
+                    model_runner_trtllm,
+                    model_runner_reference,
+                    trtllm_backend,
+                    reference_backend,
+                    layer,
+                ) = self._create_model_components(config, is_prefill=True)
+
+                # Prefill uses full sequences
+                seq_lens = torch.full(
+                    (batch_size,), max_seq_len, device=config["device"]
+                )
+
+                def _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens,
+                    extend_prefix_lens,
+                    backend,
+                    model_runner,
+                    config,
+                ):
+                    """Create a forward batch for the given backend."""
+
+                    fb = ForwardBatch(
+                        batch_size=batch_size,
+                        input_ids=torch.randint(
+                            0, 100, (batch_size, 1), device=config["device"]
+                        ),
+                        out_cache_loc=torch.arange(batch_size, device=config["device"]),
+                        seq_lens_sum=int(seq_lens.sum().item()),
+                        extend_prefix_lens=extend_prefix_lens,
+                        extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
+                        extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
+                        .cpu()
+                        .int()
+                        .tolist(),
+                        forward_mode=ForwardMode.EXTEND,
+                        req_pool_indices=torch.arange(
+                            batch_size, device=config["device"]
+                        ),
+                        seq_lens=seq_lens,
+                        seq_lens_cpu=seq_lens.cpu(),
+                        attn_attend_prefix_cache=False,
+                        mha_return_lse=False,
+                        attn_backend=backend,
+                    )
+                    fb.req_to_token_pool = model_runner.req_to_token_pool
+                    fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+                    # Add position information for RoPE
+                    fb.positions = torch.arange(batch_size, device=config["device"])
+
+                    return fb
+
+                # Create forward batches
+                fb_trtllm = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    trtllm_backend,
+                    model_runner_trtllm,
+                    config,
+                )
+                fb_reference = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    reference_backend,
+                    model_runner_reference,
+                    config,
+                )
+
+                # Initialize metadata for both backends
+                trtllm_backend.init_forward_metadata(fb_trtllm)
+                reference_backend.init_forward_metadata(fb_reference)
+
+                # Create Q, K, V tensors for prefill
+                torch.manual_seed(config["seed_qkv"])
+
+                def _create_qkv_tensors_prefill(
+                    batch_size, seq_len, config, dtype_override=None
+                ):
+                    """Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
+                    device = config["device"]
+                    dtype = dtype_override or config["dtype"]
+
+                    total_tokens = batch_size * seq_len
+
+                    tp_q_head_num = config["tp_q_head_num"]
+                    tp_k_head_num = config["tp_k_head_num"]
+                    head_dim = config["prefill_head_dim"]
+                    v_head_dim = config["prefill_v_head_dim"]
+
+                    q = torch.randn(
+                        (total_tokens, tp_q_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    k = torch.randn(
+                        (total_tokens, tp_k_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    v = torch.randn(
+                        (total_tokens, tp_k_head_num * v_head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+
+                    # Reshape as requested
+                    q = q.view(-1, tp_q_head_num, head_dim)
+                    k = k.view(-1, tp_k_head_num, head_dim)
+                    v = v.view(-1, tp_k_head_num, v_head_dim)
+
+                    return q, k, v
+
+                q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
+                # Run prefill on both backends
+                out_trtllm = trtllm_backend.forward_extend(
+                    q, k, v, layer, fb_trtllm, False
+                ).view(-1, layer.tp_q_head_num * layer.v_head_dim)
+                out_reference = reference_backend.forward_extend(
+                    q, k, v, layer, fb_reference, False
+                )
+
+                tolerance = config.get("tolerance", 1e-2)
+                comparison_passed = compare_outputs(
+                    out_trtllm, out_reference, tolerance=tolerance
+                )
+                self.assertTrue(
+                    comparison_passed,
+                    f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
+                    f"Config: {test_case['name']}, "
+                    f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
+                )
+

 if __name__ == "__main__":
     unittest.main()
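To summarize the head-dimension logic that the updated `_create_model_components` applies above, here is a standalone restatement; the helper function itself is illustrative and not part of the test file.

```python
# Illustrative restatement of the new head-dim selection in the MLA test.
def select_head_dims(config: dict, is_prefill: bool) -> tuple[int, int]:
    """Return (head_dim, v_head_dim) for building the RadixAttention layer."""
    if is_prefill:
        # Prefill path uses the dedicated prefill head dims added to DEFAULT_CONFIG.
        return config["prefill_head_dim"], config["prefill_v_head_dim"]
    # Decode path keeps the compressed MLA layout: kv_lora_rank + rope head dim.
    return config["kv_lora_rank"] + config["qk_rope_head_dim"], config["v_head_dim"]

# With the DEFAULT_CONFIG values above: prefill -> (192, 128);
# decode -> (kv_lora_rank + qk_rope_head_dim, 512).
```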
sglang/test/few_shot_gsm8k.py
CHANGED
sglang/test/runners.py
CHANGED
@@ -505,6 +505,7 @@ class SRTRunner:
         mem_fraction_static: float = 0.65,
         trust_remote_code: bool = False,
         speculative_draft_model_path: Optional[str] = None,
+        speculative_draft_model_revision: Optional[str] = None,
         speculative_algorithm: Optional[str] = None,
         speculative_num_steps: Optional[int] = None,
         speculative_eagle_topk: Optional[int] = None,
@@ -526,6 +527,9 @@
         spec_kwargs = {}
         if speculative_draft_model_path:
             spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
+            spec_kwargs["speculative_draft_model_revision"] = (
+                speculative_draft_model_revision
+            )
             spec_kwargs["speculative_algorithm"] = speculative_algorithm
             spec_kwargs["speculative_num_steps"] = speculative_num_steps
             spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
sglang/test/test_cutlass_moe.py
CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoConfig
 from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput


 # Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
@@ -21,7 +22,7 @@ def calc_diff(x, y):

 def get_model_config(tp_size: int):
     config = AutoConfig.from_pretrained(
-        "deepseek-ai/
+        "deepseek-ai/Deepseek-R1", trust_remote_code=True
     )
     E = config.n_routed_experts
     topk = config.num_experts_per_tok
@@ -152,14 +153,31 @@ def run_test(tp_size, batch_size, model_config, check=False):
         problem_sizes2,
     )

+    topk_output = StandardTopKOutput(
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        router_logits=torch.randn(
+            (batch_size, topk), device=topk_weights.device, dtype=dtype
+        ),
+    )
+
+    moe_runner_config = MoeRunnerConfig(
+        num_experts=E,
+        top_k=topk,
+        hidden_size=H,
+        intermediate_size_per_partition=I,
+        params_dtype=dtype,
+        activation="silu",
+        inplace=False,
+    )
+
     # Note: Triton expects non-transposed weights
-    moe_config = MoeRunnerConfig(inplace=False)
     triton_lambda = lambda: fused_experts(
         x,
         w1,
         w2,
-
-
+        topk_output,
+        moe_runner_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
@@ -224,8 +242,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
         x,
         w1,  # Original shape
         w2,  # Original shape
-
-
+        topk_output,
+        moe_runner_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
sglang/test/test_cutlass_w4a8_moe.py
CHANGED
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Optional
+from typing import Literal, Optional

 import pytest
 import torch
@@ -25,7 +25,7 @@ def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Ten
     return packed_tensor.to(torch.int8)


-def pack_interleave(num_experts, ref_weight, ref_scale):
+def pack_interleave(num_experts, ref_weight, ref_scale, alignment=4):
     n, k = ref_weight.shape[1], ref_weight.shape[2]

     weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
@@ -33,11 +33,16 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
     w_q = w_q.contiguous()

     scale_interleaved = ref_scale.reshape(
-        ref_scale.shape[0],
+        ref_scale.shape[0],
+        ref_scale.shape[1],
+        (ref_scale.shape[2] // alignment),
+        alignment,
     )  # [E, N, K/4, 4]
     scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
     scale_interleaved = scale_interleaved.reshape(
-        ref_scale.shape[0],
+        ref_scale.shape[0],
+        ref_scale.shape[2] // alignment,
+        ref_scale.shape[1] * alignment,
     )  # [E, K/4, N*4]
     w_scale = scale_interleaved.contiguous()

@@ -48,12 +53,17 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
 @pytest.mark.parametrize("N", [2048])
 @pytest.mark.parametrize("K", [7168])
 @pytest.mark.parametrize("E", [256])
-@pytest.mark.parametrize("
+@pytest.mark.parametrize("tp_size", [8])
+@pytest.mark.parametrize("use_ep_moe", [True, False])
 @pytest.mark.parametrize("topk", [8])
 @pytest.mark.parametrize("group_size", [128])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_cutlass_w4a8_moe(M, N, K, E,
-
+def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dtype):
+    if use_ep_moe:
+        local_e = E // tp_size
+    else:  # tp mode
+        local_e = E
+        N = N // tp_size

     debug = False
     if debug:
@@ -87,7 +97,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
     )

     w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
-
+    if use_ep_moe:
+        w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
+    else:
+        w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2, 1)

     device = "cuda"
     a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
@@ -265,7 +278,9 @@ def ref(

     gate, fc1 = fc1.chunk(2, dim=-1)
     fc1 = fc1 * torch.nn.functional.silu(gate)
-    act = (fc1 / pre_quant_scale_2.float()).to(
+    act = torch.clamp((fc1 / pre_quant_scale_2.float()), -448.0, 448.0).to(
+        torch.float8_e4m3fn
+    )
     act = act.to(dtype)

     w2 = ref_weight_2[e_idx]
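The `alignment` parameter added to `pack_interleave` above only changes how the scale tensor is regrouped; a standalone sketch of that reshuffle, with arbitrary sizes, is shown below.

```python
# Standalone sketch of the scale interleaving in pack_interleave; the sizes
# are arbitrary (chosen so K is divisible by the alignment).
import torch

E, N, K, alignment = 2, 4, 8, 4
ref_scale = torch.arange(E * N * K, dtype=torch.float32).reshape(E, N, K)

s = ref_scale.reshape(E, N, K // alignment, alignment)  # [E, N, K/4, 4]
s = s.permute(0, 2, 1, 3)                               # [E, K/4, N, 4]
w_scale = s.reshape(E, K // alignment, N * alignment).contiguous()  # [E, K/4, N*4]
assert w_scale.shape == (E, K // alignment, N * alignment)
```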
sglang/test/test_disaggregation_utils.py
ADDED
@@ -0,0 +1,66 @@
+import time
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    CustomTestCase,
+    popen_with_error_check,
+)
+
+
+class TestDisaggregationBase(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
+        pass
+
+    @classmethod
+    def launch_lb(cls):
+        lb_command = [
+            "python3",
+            "-m",
+            "sglang_router.launch_router",
+            "--pd-disaggregation",
+            "--mini-lb",  # FIXME: remove this
+            "--prefill",
+            cls.prefill_url,
+            "--decode",
+            cls.decode_url,
+            "--host",
+            cls.base_host,
+            "--port",
+            cls.lb_port,
+        ]
+        print("Starting load balancer:", " ".join(lb_command))
+        cls.process_lb = popen_with_error_check(lb_command)
+        cls.wait_server_ready(cls.lb_url + "/health")
+
+    @classmethod
+    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+
+        # wait for 5 seconds
+        time.sleep(5)
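The new TestDisaggregationBase leaves the endpoint attributes to its subclasses; a hypothetical subclass could wire them up as sketched below (hosts, ports, and the server-launch step are placeholders, not taken from the package).

```python
# Hypothetical subclass sketch; URLs and ports are placeholders, and launching
# the prefill/decode servers is elided.
class TestMyDisaggregation(TestDisaggregationBase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.base_host = "127.0.0.1"
        cls.lb_port = "8100"
        cls.prefill_url = "http://127.0.0.1:30000"
        cls.decode_url = "http://127.0.0.1:30100"
        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
        # ... launch prefill/decode servers and keep their handles in
        # cls.process_prefill / cls.process_decode ...
        cls.launch_lb()
```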
sglang/test/test_utils.py
CHANGED
@@ -42,7 +42,8 @@ DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-
+DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE = "Qwen/Qwen1.5-MoE-A2.7B"
+DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT = "Qwen/Qwen1.5-MoE-A2.7B-Chat"

 # MLA test models
 DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
@@ -72,6 +73,10 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
 DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
+DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
+    "meta-llama/Llama-3.1-8B-Instruct"
+)
+DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"

 # Other use cases
 DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
@@ -466,6 +471,25 @@ def try_cached_model(model_repo: str):
     return model_dir if model_dir else model_repo


+def popen_with_error_check(command: list[str], allow_exit: bool = False):
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+    def _run_and_check():
+        stdout, stderr = process.communicate()
+
+        while process.poll() is None:
+            time.sleep(5)
+
+        if not allow_exit or process.returncode != 0:
+            raise Exception(
+                f"{command} exited with code {process.returncode}\n{stdout=}\n{stderr=}"
+            )
+
+    t = threading.Thread(target=_run_and_check)
+    t.start()
+    return process
+
+
 def popen_launch_server(
     model: str,
     base_url: str,
sglang/utils.py
CHANGED
@@ -457,6 +457,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
     NOTE: Typically, the server runs in a separate terminal.
     In this notebook, we run the server and notebook code together, so their outputs are combined.
     To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
+    To reduce the log length, we set the log level to warning for the server, the default log level is info.
     We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
     """
     )
@@ -472,6 +473,10 @@ class TypeBasedDispatcher:
     def __init__(self, mapping: List[Tuple[Type, Callable]]):
         self._mapping = mapping

+    def __iadd__(self, other: "TypeBasedDispatcher"):
+        self._mapping.extend(other._mapping)
+        return self
+
     def __call__(self, obj: Any):
         for ty, fn in self._mapping:
             if isinstance(obj, ty):
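A short sketch of what the new `TypeBasedDispatcher.__iadd__` enables; the handlers are placeholder lambdas, and only the mapping-extension behavior comes from the diff above.

```python
# Merging two dispatchers in place with the new __iadd__.
from sglang.utils import TypeBasedDispatcher

base = TypeBasedDispatcher([(int, lambda x: f"int:{x}")])
base += TypeBasedDispatcher([(str, lambda x: f"str:{x}")])

# The mapping of `base` now holds both the int and the str handler, so a
# single dispatcher instance can route both object types.
assert len(base._mapping) == 2
```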
sglang/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.5.1.post3"
+__version__ = "0.5.2"