sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -13,11 +13,14 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
+import datetime
 import gc
 import json
 import logging
+import os
 import time
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -34,6 +37,7 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
@@ -51,14 +55,18 @@ from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
@@ -69,10 +77,15 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
+from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
 
+SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
@@ -86,6 +99,8 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -103,68 +118,21 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
         # Model-specific adjustment
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
-            # TODO: add MLA optimization on CPU
-            if self.server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
-                    logger.info(
-                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
-                    )
-                    self.server_args.attention_backend = "flashinfer"
-                else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    self.server_args.attention_backend = "triton"
-
-        if self.server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
-            self.server_args.attention_backend = "triton"
-            self.server_args.disable_cuda_graph = True
-            if self.server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(
-                self.server_args.ds_heavy_channel_type
-            )
+        self.model_specific_adjustment()
 
-        if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-
-        if self.model_config.hf_config.architectures == [
-            "MllamaForConditionalGeneration"
-        ]:
-            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-            server_args.chunked_prefill_size = -1
-
-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
-            )
-            server_args.chunked_prefill_size = -1
-            server_args.disable_radix_cache = True
-
-        # Global vars
         if server_args.show_time_cost:
            enable_show_time_cost()
+
         if server_args.disable_outlines_disk_cache:
             from outlines.caching import disable_cache
 
             disable_cache()
 
+        # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -176,11 +144,17 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "disable_radix_cache": server_args.disable_radix_cache,
+                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
             }
         )
 
+        # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
 
         # Get memory before model loading
@@ -210,9 +184,11 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
-        # Init memory pool and attention backends
+        # Init lora
         if server_args.lora_paths is not None:
             self.init_lora_manager()
+
+        # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
             server_args.max_running_requests,
@@ -226,6 +202,59 @@ class ModelRunner:
         self.cuda_graph_runner = None
         self.init_attention_backend()
 
+    def model_specific_adjustment(self):
+        server_args = self.server_args
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        ):
+            # TODO: add MLA optimization on CPU
+            if server_args.device != "cpu":
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "MLA optimization is turned on. Use flashinfer mla backend."
+                    )
+                    server_args.attention_backend = "flashinfer_mla"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    server_args.attention_backend = "triton"
+
+        if server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            server_args.attention_backend = "triton"
+            server_args.disable_cuda_graph = True
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+        if self.is_multimodal:
+            self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
+        if self.model_config.hf_config.architectures == [
+            "MllamaForConditionalGeneration"
+        ]:
+            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+            server_args.chunked_prefill_size = -1
+
+        if self.model_config.hf_config.architectures == [
+            "Qwen2VLForConditionalGeneration"
+        ]:
+            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 
@@ -233,14 +262,13 @@
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
-            # TODO(liangan1): Just use gloo to bypass the initilization fail
-            # Need to use xccl for xpu backend in the future
-            backend = "gloo"
+            backend = "xccl"
         elif self.device == "hpu":
             backend = "hccl"
         elif self.device == "cpu":
             backend = "gloo"
 
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
             monkey_patch_p2p_access_check()
 
@@ -258,6 +286,7 @@
                 rank=self.tp_rank,
                 local_rank=self.gpu_id,
                 distributed_init_method=dist_init_method,
+                timeout=self.server_args.dist_timeout,
             )
             initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
             initialize_dp_attention(
@@ -270,20 +299,24 @@
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
 
+        logger.info(
+            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
+        )
         return min_per_gpu_memory
 
     def load_model(self):
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
@@ -353,13 +386,27 @@
         )
         self.dtype = self.model_config.dtype
 
+        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"avail mem={after_avail_memory:.2f} GB, "
+            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )
 
+        # Handle the case where some ranks do not finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
@@ -512,8 +559,21 @@
             logger.error(error_msg)
             return False, error_msg
 
-    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
-        self.model.load_weights(named_tensors)
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
+        load_format: Optional[str] = None,
+    ):
+        named_tensors = [
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            for name, tensor in named_tensors
+        ]
+        if load_format == "direct":
+            _model_load_weights_direct(self.model, named_tensors)
+        elif load_format is None:
+            self.model.load_weights(named_tensors)
+        else:
+            raise NotImplementedError(f"Unknown load_format={load_format}")
         return True, "Success"
 
     def get_weights_by_name(
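
Note: the reworked update_weights_from_tensor above accepts either plain tensors or per-rank serialized handles and adds an optional load_format argument. A minimal caller-side sketch of the new path follows; LocalSerializedTensor, _unwrap_tensor, and MultiprocessingSerializer.deserialize all appear later in this diff, while MultiprocessingSerializer.serialize and the helper wrapper here are assumptions made only for illustration.

# Hypothetical illustration of the per-rank weight-update path added in this
# release. Only LocalSerializedTensor and MultiprocessingSerializer.deserialize
# are confirmed by the diff; the serialize() call is assumed to be its inverse.
from typing import List, Tuple

import torch

from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.utils import MultiprocessingSerializer


def pack_named_tensor(
    name: str, per_rank_shards: List[torch.Tensor]
) -> Tuple[str, LocalSerializedTensor]:
    # Serialize one shard per TP rank; rank i later recovers its shard via
    # LocalSerializedTensor.get(i) inside ModelRunner.update_weights_from_tensor().
    handles = [MultiprocessingSerializer.serialize(t) for t in per_rank_shards]
    return name, LocalSerializedTensor(values=handles)
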
@@ -606,15 +666,31 @@
             4096,
         )
 
+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
+
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                max_num_reqs = self.server_args.max_num_reqs
             else:
+                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
+                # can be concurrently allocated, so we should give a headroom for it.
                 self.server_args.draft_runner_cache_size = (
                     self.max_total_num_tokens
-                    + max_num_reqs * self.server_args.speculative_num_steps
+                    # draft
+                    + max_num_reqs
+                    * self.server_args.speculative_num_steps
+                    * self.server_args.speculative_eagle_topk
+                    # verify
+                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
+                    # buffer
                     + 100
                 )
+                # Target worker and draft worker shares the same indices for the
+                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                self.server_args.max_num_reqs = max_num_reqs
 
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
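
Note: the headroom formula above gains separate draft and verify terms in place of the single speculative_num_steps term it replaces. A standalone computation with made-up values, purely to illustrate the arithmetic:

# Illustration of the new draft_runner_cache_size formula; every value below is
# invented for the example and is not a recommended setting.
max_total_num_tokens = 500_000        # profiled KV capacity of the target model
max_num_reqs = 256                    # maximum concurrent requests
speculative_num_steps = 5             # autoregressive draft steps
speculative_eagle_topk = 4            # branches kept per draft step
speculative_num_draft_tokens = 8      # tokens sent to the target for verification

draft_runner_cache_size = (
    max_total_num_tokens
    + max_num_reqs * speculative_num_steps * speculative_eagle_topk  # draft: 5,120
    + max_num_reqs * speculative_num_draft_tokens                    # verify: 2,048
    + 100                                                            # buffer
)
print(draft_runner_cache_size)  # 507,268
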
@@ -630,12 +706,17 @@
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
-        self.req_to_token_pool = ReqToTokenPool(
-            size=max_num_reqs + 1,
-            max_context_len=self.model_config.context_len + 4,
-            device=self.device,
-            enable_memory_saver=self.server_args.enable_memory_saver,
-        )
+        if self.req_to_token_pool is None:
+            self.req_to_token_pool = ReqToTokenPool(
+                size=max_num_reqs + 1,
+                max_context_len=self.model_config.context_len + 4,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+            )
+        else:
+            # Draft worker shares req_to_token_pool with the target worker.
+            assert self.is_draft_worker
+
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -670,6 +751,17 @@
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+                kvcache=self.token_to_kv_pool,
+            )
+        else:
+            assert self.is_draft_worker
+
         logger.info(
             f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -687,6 +779,10 @@
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            # Init streams
+            if self.server_args.speculative_algorithm == "EAGLE":
+                self.plan_stream_for_flashinfer = torch.cuda.Stream()
+
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
@@ -703,6 +799,8 @@
             self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
             self.attn_backend = TorchNativeAttnBackend(self)
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
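
Note: with the new flashinfer_mla backend registered above, init_attention_backend now recognizes four backend names. A condensed, illustration-only summary of the mapping (the real method constructs each backend lazily and performs additional checks, such as the sliding-window assertion visible above):

# Illustration-only lookup table; the class names are the ones imported at the
# top of model_runner.py in this diff.
ATTENTION_BACKENDS = {
    "flashinfer": "FlashInferAttnBackend",
    "flashinfer_mla": "FlashInferMLAAttnBackend",  # new in this release
    "triton": "TritonAttnBackend",
    "torch_native": "TorchNativeAttnBackend",
}


def attn_backend_class_name(name: str) -> str:
    if name not in ATTENTION_BACKENDS:
        raise ValueError(f"Invalid attention backend: {name}")
    return ATTENTION_BACKENDS[name]
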
@@ -737,9 +835,16 @@
             return
 
         tic = time.time()
-        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        logger.info(
+            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+        )
         self.cuda_graph_runner = CudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        logger.info(
+            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+        )
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -754,8 +859,12 @@
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
-    def forward_extend(self, forward_batch: ForwardBatch):
-        self.attn_backend.init_forward_metadata(forward_batch)
+    def forward_extend(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ):
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
+
         if self.is_generation:
             if forward_batch.input_embeds is None:
                 return self.model.forward(
@@ -782,28 +891,33 @@
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
-    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+    def forward(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if (
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         ):
-            return self.cuda_graph_runner.replay(forward_batch)
+            return self.cuda_graph_runner.replay(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
 
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(forward_batch)
+            return self.forward_extend(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
-    def sample(
-        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
-    ) -> torch.Tensor:
+    def _preprocess_logits(
+        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
+    ):
         # Apply logit bias
-        sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
             # in process_batch_result of the last batch.
@@ -812,15 +926,77 @@
         else:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
-            sampling_info.update_penalties()
             sampling_info.apply_logits_bias(logits_output.next_token_logits)
 
+    def update_output_logprobs(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[int],
+        next_token_ids: torch.Tensor,
+        *,
+        num_tokens_per_req: List[int],
+    ):
+        """Update the logits_output's output logprob based on next_token_ids
+
+        Args:
+            logits_output: The logits output from the model forward
+            sampling_info: Sampling info for logprob calculation
+            top_logprobs_nums: Number of logprobs per request.
+            next_token_ids: Next token ids.
+            num_tokens_per_req: The number of tokens per request.
+
+        Returns:
+            A list of next_token_ids
+        """
+        self._preprocess_logits(logits_output, sampling_info)
+        # We should repeat top_logprobs_nums to match num_tokens_per_req.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+        self.sampler(
+            logits_output,
+            sampling_info,
+            True,
+            top_logprobs_nums_repeat_interleaved,
+            token_ids_logprobs_repeat_interleaved,
+            batch_next_token_ids=next_token_ids,
+        )
+
+    def sample(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        """Sample and compute logprobs and update logits_output.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+
+        Returns:
+            A list of next_token_ids
+        """
+        # For duplex models with multiple output streams.
+        if isinstance(logits_output, tuple):
+            return torch.stack(
+                [self.sample(values, forward_batch) for values in logits_output],
+                axis=-1,
+            )
+
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
-            sampling_info,
+            forward_batch.sampling_info,
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
         )
         return next_token_ids
 
@@ -832,3 +1008,26 @@
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
+
+
+def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
+    params_dict = dict(model.named_parameters())
+    for name, tensor in named_tensors:
+        default_weight_loader(params_dict[name], tensor)
+
+
+def _unwrap_tensor(tensor, tp_rank):
+    if isinstance(tensor, LocalSerializedTensor):
+        return tensor.get(tp_rank)
+    return tensor
+
+
+@dataclass
+class LocalSerializedTensor:
+    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
+    The i-th element in the list corresponds to i-th rank's GPU."""
+
+    values: List[bytes]
+
+    def get(self, rank: int):
+        return MultiprocessingSerializer.deserialize(self.values[rank])
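
Note: update_output_logprobs above expands its per-request arguments to per-token granularity before invoking the sampler. A standalone illustration of that repeat-interleave step, with invented values:

# Standalone illustration of the repeat-interleave step in
# ModelRunner.update_output_logprobs(); the inputs are invented for the example.
top_logprobs_nums = [5, 0]     # requested top-k logprobs, one entry per request
num_tokens_per_req = [3, 2]    # e.g. accepted speculative tokens per request

top_logprobs_nums_repeat_interleaved = []
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
    top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)

print(top_logprobs_nums_repeat_interleaved)  # [5, 5, 5, 0, 0]
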
sglang/srt/model_loader/loader.py

@@ -11,7 +11,7 @@ import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, cast
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
 import gguf
 import huggingface_hub
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.device_config import DeviceConfig
@@ -197,7 +197,7 @@ class DefaultModelLoader(BaseModelLoader):
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             # pylint: disable=C.
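
Note: the ModelScope switch above is now stricter: the branch is taken only when SGLANG_USE_MODELSCOPE is set to the literal string "True", whereas previously any defined value enabled it. A quick standalone check of the new condition:

# Demonstrates the tightened check used in DefaultModelLoader._maybe_download_from_modelscope().
import os

for value in ("True", "1", "true", ""):
    os.environ["SGLANG_USE_MODELSCOPE"] = value
    enabled = os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True"
    print(f"SGLANG_USE_MODELSCOPE={value!r} -> ModelScope path taken: {enabled}")
# Only the first setting prints True.
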