sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py

@@ -31,7 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 import triton
@@ -41,12 +41,13 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.attention import AttentionBackend
+    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 
 
 class ForwardMode(IntEnum):
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
 
 class CaptureHiddenMode(IntEnum):
     NULL = auto()
+    # Capture hidden states of all tokens.
     FULL = auto()
+    # Capture a hidden state of the last token.
     LAST = auto()
 
     def need_capture(self):
@@ -148,10 +151,14 @@ class ForwardBatch:
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None
 
     # Position information
     positions: torch.Tensor = None
 
+    # For decode
+    decode_seq_lens_cpu: Optional[torch.Tensor] = None
+
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None
@@ -160,6 +167,7 @@ class ForwardBatch:
     extend_prefix_lens_cpu: Optional[List[int]] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
+    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
 
     # For multimodal
     image_inputs: Optional[List[ImageInputs]] = None
@@ -185,15 +193,27 @@ class ForwardBatch:
     attn_backend: AttentionBackend = None
 
     # For DP attention
-    global_num_tokens: Optional[List[int]] = None
+    global_num_tokens_cpu: Optional[List[int]] = None
+    global_num_tokens_gpu: Optional[torch.Tensor] = None
+    # Has to be None when cuda graph is captured.
+    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # for extend, local start pos and num tokens is different in logits processor
+    # this will be computed in get_dp_local_info
+    # this will be recomputed in LogitsMetadata.from_forward_batch
+    dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
+    dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
 
     # Speculative decoding
-    spec_info: SpecInfo = None
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     spec_algorithm: SpeculativeAlgorithm = None
     capture_hidden_mode: CaptureHiddenMode = None
 
+    # For padding
+    padded_static_len: int = -1  # -1 if not padded
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
@@ -203,8 +223,13 @@ class ForwardBatch:
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
-
         device = model_runner.device
+        extend_input_logprob_token_ids_gpu = None
+        if batch.extend_input_logprob_token_ids is not None:
+            extend_input_logprob_token_ids_gpu = (
+                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
+            )
+
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -220,7 +245,7 @@ class ForwardBatch:
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
-            global_num_tokens=batch.global_num_tokens,
+            token_ids_logprobs=batch.token_ids_logprobs,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
@@ -231,10 +256,12 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )
 
-        if ret.global_num_tokens is not None:
-            max_len = max(ret.global_num_tokens)
+        if batch.global_num_tokens is not None:
+            ret.global_num_tokens_cpu = batch.global_num_tokens
+            max_len = max(ret.global_num_tokens_cpu)
             ret.gathered_buffer = torch.zeros(
                 (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
@@ -256,6 +283,8 @@ class ForwardBatch:
         if ret.forward_mode.is_decode():
             if ret.positions is None:
                 ret.positions = clamp_position(batch.seq_lens)
+            if ret.decode_seq_lens_cpu is None:
+                ret.decode_seq_lens_cpu = batch.decode_seq_lens
         else:
            ret.extend_seq_lens = torch.tensor(
                batch.extend_seq_lens, dtype=torch.int32
@@ -263,13 +292,12 @@ class ForwardBatch:
            ret.extend_prefix_lens = torch.tensor(
                batch.extend_prefix_lens, dtype=torch.int32
            ).to(device, non_blocking=True)
-            if (
-                model_runner.server_args.attention_backend != "torch_native"
-                and model_runner.server_args.speculative_algorithm != "NEXTN"
-            ):
+            if model_runner.server_args.attention_backend != "torch_native":
                ret.extend_num_tokens = batch.extend_num_tokens
                positions, ret.extend_start_loc = compute_position_triton(
-                    ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+                    ret.extend_prefix_lens,
+                    ret.extend_seq_lens,
+                    ret.extend_num_tokens,
                )
            else:
                positions, ret.extend_start_loc = compute_position_torch(
@@ -341,6 +369,7 @@ class ForwardBatch:
                )
                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
                mrope_positions_list[i] = mrope_positions
+
        self.mrope_positions = torch.concat(
            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
            axis=1,
@@ -353,6 +382,8 @@ def compute_position_triton(
 ):
     """Compute positions. It is a fused version of `compute_position_torch`."""
     batch_size = extend_seq_lens.shape[0]
+    has_prefix = extend_prefix_lens.shape[0] == batch_size
+
     positions = torch.empty(
         extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
     )
@@ -366,6 +397,7 @@ def compute_position_triton(
         extend_start_loc,
         extend_prefix_lens,
         extend_seq_lens,
+        has_prefix,
     )
 
     return positions, extend_start_loc
@@ -377,11 +409,12 @@ def compute_position_kernel(
     extend_start_loc,
     extend_prefix_lens,
     extend_seq_lens,
+    has_prefix: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
-    prefix_len = tl.load(extend_prefix_lens + pid)
+    prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)
 
     # TODO: optimize this?
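
For reference, a minimal plain-PyTorch sketch of what the position computation appears to produce (it mirrors the non-fused `compute_position_torch` path referenced above, not the Triton kernel); the variable names below are illustrative, and the `has_prefix` flag added in this release simply treats prefix lengths as zero when the prefix tensor does not cover the batch:

import torch

def compute_positions_reference(extend_prefix_lens, extend_seq_lens):
    # Each request's new (extend) tokens get positions
    # [prefix_len, prefix_len + seq_len), and extend_start_loc gives the
    # offset of each request's first token in the flattened positions tensor.
    positions = torch.cat(
        [
            torch.arange(prefix, prefix + seq_len, dtype=torch.int64)
            for prefix, seq_len in zip(
                extend_prefix_lens.tolist(), extend_seq_lens.tolist()
            )
        ]
    )
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
    return positions, extend_start_loc

# Example: two requests, one with a 4-token cached prefix.
positions, start_loc = compute_positions_reference(
    torch.tensor([4, 0]), torch.tensor([3, 2])
)
# positions -> [4, 5, 6, 0, 1], start_loc -> [0, 3]
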
sglang/srt/model_executor/model_runner.py

@@ -13,11 +13,14 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
+import datetime
 import gc
 import json
 import logging
+import os
 import time
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -34,6 +37,7 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
@@ -51,14 +55,18 @@ from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
@@ -69,10 +77,15 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
+from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
 
+SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
@@ -86,6 +99,8 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -103,6 +118,8 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
         # Model-specific adjustment
         if (
@@ -113,9 +130,9 @@ class ModelRunner:
             if self.server_args.device != "cpu":
                 if server_args.enable_flashinfer_mla:
                     logger.info(
-                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                        "MLA optimization is turned on. Use flashinfer mla backend."
                     )
-                    self.server_args.attention_backend = "flashinfer"
+                    self.server_args.attention_backend = "flashinfer_mla"
                 else:
                     logger.info("MLA optimization is turned on. Use triton backend.")
                     self.server_args.attention_backend = "triton"
@@ -176,7 +193,13 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "disable_radix_cache": server_args.disable_radix_cache,
+                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
             }
         )
 
@@ -193,6 +216,18 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
+        # Handle the case where some of models don't finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
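
A standalone sketch of the barrier pattern introduced above, using a single-process gloo group purely for illustration; torch.distributed.monitored_barrier is only supported on the gloo backend, which is presumably why the runner issues it on the TP group's CPU sub-group rather than the NCCL group:

import datetime

import torch.distributed as dist

# Single-process gloo group, just to show the call shape.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1
)
try:
    # Raises RuntimeError if some rank fails to reach the barrier in time.
    dist.monitored_barrier(
        timeout=datetime.timedelta(seconds=300), wait_all_ranks=True
    )
except RuntimeError as e:
    print(f"A peer rank did not finish loading: {e}")
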
@@ -227,19 +262,18 @@ class ModelRunner:
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-
         torch.get_device_module(self.device).set_device(self.gpu_id)
+
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
-            # TODO(liangan1): Just use gloo to bypass the initilization fail
-            # Need to use xccl for xpu backend in the future
-            backend = "gloo"
+            backend = "xccl"
         elif self.device == "hpu":
             backend = "hccl"
         elif self.device == "cpu":
             backend = "gloo"
 
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
             monkey_patch_p2p_access_check()
 
@@ -257,6 +291,7 @@ class ModelRunner:
             rank=self.tp_rank,
             local_rank=self.gpu_id,
             distributed_init_method=dist_init_method,
+            timeout=self.server_args.dist_timeout,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         initialize_dp_attention(
@@ -269,20 +304,24 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
 
+        logger.info(
+            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
+        )
         return min_per_gpu_memory
 
     def load_model(self):
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
@@ -352,11 +391,13 @@ class ModelRunner:
         )
         self.dtype = self.model_config.dtype
 
+        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"avail mem={after_avail_memory:.2f} GB, "
+            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )
 
     def update_weights_from_disk(
@@ -511,8 +552,21 @@ class ModelRunner:
             logger.error(error_msg)
             return False, error_msg
 
-    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
-        self.model.load_weights(named_tensors)
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
+        load_format: Optional[str] = None,
+    ):
+        named_tensors = [
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            for name, tensor in named_tensors
+        ]
+        if load_format == "direct":
+            _model_load_weights_direct(self.model, named_tensors)
+        elif load_format is None:
+            self.model.load_weights(named_tensors)
+        else:
+            raise NotImplementedError(f"Unknown load_format={load_format}")
         return True, "Success"
 
     def get_weights_by_name(
@@ -605,15 +659,31 @@ class ModelRunner:
                 4096,
             )
 
+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
+
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                max_num_reqs = self.server_args.max_num_reqs
             else:
+                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
+                # can be concurrently allocated, so we should give a headroom for it.
                 self.server_args.draft_runner_cache_size = (
                     self.max_total_num_tokens
-                    + max_num_reqs * self.server_args.speculative_num_steps
+                    # draft
+                    + max_num_reqs
+                    * self.server_args.speculative_num_steps
+                    * self.server_args.speculative_eagle_topk
+                    # verify
+                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
+                    # buffer
                    + 100
                )
+                # Target worker and draft worker shares the same indices for the
+                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                self.server_args.max_num_reqs = max_num_reqs
 
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
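
For intuition, a worked instance of the new draft_runner_cache_size formula above, with invented example values standing in for the real ServerArgs settings:

# Illustrative numbers only; the real values come from ServerArgs.
max_total_num_tokens = 200_000
max_num_reqs = 64
speculative_num_steps = 3
speculative_eagle_topk = 4
speculative_num_draft_tokens = 8

draft_runner_cache_size = (
    max_total_num_tokens
    # draft: each request may hold steps * topk draft slots at once
    + max_num_reqs * speculative_num_steps * speculative_eagle_topk
    # verify: plus the draft tokens submitted for verification
    + max_num_reqs * speculative_num_draft_tokens
    # buffer
    + 100
)
print(draft_runner_cache_size)  # 200_000 + 768 + 512 + 100 = 201_380
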
@@ -629,12 +699,26 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
-        self.req_to_token_pool = ReqToTokenPool(
-            size=max_num_reqs + 1,
-            max_context_len=self.model_config.context_len + 4,
-            device=self.device,
-            enable_memory_saver=self.server_args.enable_memory_saver,
-        )
+        if self.req_to_token_pool is None:
+            self.req_to_token_pool = ReqToTokenPool(
+                size=max_num_reqs + 1,
+                max_context_len=self.model_config.context_len + 4,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+            )
+        else:
+            # Draft worker shares req_to_token_pool with the target worker.
+            assert self.is_draft_worker
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        else:
+            assert self.is_draft_worker
+
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -702,6 +786,8 @@ class ModelRunner:
             self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
             self.attn_backend = TorchNativeAttnBackend(self)
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -736,9 +822,16 @@ class ModelRunner:
             return
 
         tic = time.time()
-        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        logger.info(
+            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+        )
         self.cuda_graph_runner = CudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        logger.info(
+            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+        )
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -753,8 +846,12 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
-    def forward_extend(self, forward_batch: ForwardBatch):
-        self.attn_backend.init_forward_metadata(forward_batch)
+    def forward_extend(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ):
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
+
         if self.is_generation:
             if forward_batch.input_embeds is None:
                 return self.model.forward(
@@ -798,11 +895,10 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
-    def sample(
-        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
-    ) -> torch.Tensor:
+    def _preprocess_logits(
+        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
+    ):
         # Apply logit bias
-        sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
             # in process_batch_result of the last batch.
@@ -811,15 +907,77 @@ class ModelRunner:
         else:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
-            sampling_info.update_penalties()
             sampling_info.apply_logits_bias(logits_output.next_token_logits)
 
+    def update_output_logprobs(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[int],
+        next_token_ids: torch.Tensor,
+        *,
+        num_tokens_per_req: List[int],
+    ):
+        """Update the logits_output's output logprob based on next_token_ids
+
+        Args:
+            logits_output: The logits output from the model forward
+            sampling_info: Sampling info for logprob calculation
+            top_logprobs_nums: Number of logprobs per request.
+            next_token_ids: Next token ids.
+            num_tokens_per_req: The number of tokens per request.
+
+        Returns:
+            A list of next_token_ids
+        """
+        self._preprocess_logits(logits_output, sampling_info)
+        # We should repeat top_logprobs_nums to match num_tokens_per_req.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+        self.sampler(
+            logits_output,
+            sampling_info,
+            True,
+            top_logprobs_nums_repeat_interleaved,
+            token_ids_logprobs_repeat_interleaved,
+            batch_next_token_ids=next_token_ids,
+        )
+
+    def sample(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        """Sample and compute logprobs and update logits_output.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+
+        Returns:
+            A list of next_token_ids
+        """
+        # For duplex models with multiple output streams.
+        if isinstance(logits_output, tuple):
+            return torch.stack(
+                [self.sample(values, forward_batch) for values in logits_output],
+                axis=-1,
+            )
+
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
-            sampling_info,
+            forward_batch.sampling_info,
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
         )
         return next_token_ids
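
To make the repeat-interleave step in update_output_logprobs concrete, a tiny illustration with invented values:

# Two requests: the first contributed 3 accepted tokens this step, the second 1.
top_logprobs_nums = [5, 0]
num_tokens_per_req = [3, 1]

expanded = []
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
    expanded.extend([num] * num_tokens)

print(expanded)  # [5, 5, 5, 0] -- one entry per sampled token, as the sampler expects
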
 
@@ -831,3 +989,26 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
+
+
+def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
+    params_dict = dict(model.named_parameters())
+    for name, tensor in named_tensors:
+        default_weight_loader(params_dict[name], tensor)
+
+
+def _unwrap_tensor(tensor, tp_rank):
+    if isinstance(tensor, LocalSerializedTensor):
+        return tensor.get(tp_rank)
+    return tensor
+
+
+@dataclass
+class LocalSerializedTensor:
+    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
+    The i-th element in the list corresponds to i-th rank's GPU."""
+
+    values: List[bytes]
+
+    def get(self, rank: int):
+        return MultiprocessingSerializer.deserialize(self.values[rank])
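
A minimal usage sketch of the reworked update_weights_from_tensor; the runner instance, parameter name, and shape below are hypothetical, and only the call shape follows the diff. Plain tensors take the default load_format=None path through model.load_weights, load_format="direct" writes each tensor with default_weight_loader, and LocalSerializedTensor is only needed when tensors arrive pre-serialized from another process (e.g., via the new verl_engine entry point):

import torch

# `runner` is assumed to be an already-initialized ModelRunner.
new_weights = {
    # Illustrative parameter name and shape.
    "model.layers.0.mlp.gate_proj.weight": torch.zeros(11008, 4096, dtype=torch.bfloat16),
}

ok, msg = runner.update_weights_from_tensor(
    named_tensors=list(new_weights.items()),  # plain tensors: no unwrapping needed
    load_format=None,  # default path; "direct" would use default_weight_loader instead
)
assert ok, msg
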
sglang/srt/model_loader/loader.py

@@ -11,7 +11,7 @@ import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, cast
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
 import gguf
 import huggingface_hub
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.device_config import DeviceConfig
@@ -197,7 +197,7 @@ class DefaultModelLoader(BaseModelLoader):
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
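
The last change tightens the ModelScope switch: previously any value enabled the ModelScope download path as long as the variable was set (even "0" or an empty string); now only the literal string "True" does. A tiny illustration:

import os

os.environ["SGLANG_USE_MODELSCOPE"] = "0"

old_check = "SGLANG_USE_MODELSCOPE" in os.environ              # True: the variable exists
new_check = os.environ.get("SGLANG_USE_MODELSCOPE") == "True"  # False: not the literal "True"
print(old_check, new_check)  # True False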