sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py

@@ -31,7 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 import triton
@@ -41,12 +41,13 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.attention import AttentionBackend
+    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 
 
 class ForwardMode(IntEnum):
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
 
 class CaptureHiddenMode(IntEnum):
     NULL = auto()
+    # Capture hidden states of all tokens.
     FULL = auto()
+    # Capture a hidden state of the last token.
     LAST = auto()
 
     def need_capture(self):
@@ -148,10 +151,14 @@ class ForwardBatch:
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None
 
     # Position information
     positions: torch.Tensor = None
 
+    # For decode
+    decode_seq_lens_cpu: Optional[torch.Tensor] = None
+
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None
@@ -160,6 +167,7 @@ class ForwardBatch:
     extend_prefix_lens_cpu: Optional[List[int]] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
+    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
 
     # For multimodal
     image_inputs: Optional[List[ImageInputs]] = None
@@ -185,15 +193,27 @@ class ForwardBatch:
     attn_backend: AttentionBackend = None
 
     # For DP attention
-    global_num_tokens: Optional[List[int]] = None
+    global_num_tokens_cpu: Optional[List[int]] = None
+    global_num_tokens_gpu: Optional[torch.Tensor] = None
+    # Has to be None when cuda graph is captured.
+    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # for extend, local start pos and num tokens is different in logits processor
+    # this will be computed in get_dp_local_info
+    # this will be recomputed in LogitsMetadata.from_forward_batch
+    dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
+    dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
 
     # Speculative decoding
-    spec_info: SpecInfo = None
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     spec_algorithm: SpeculativeAlgorithm = None
     capture_hidden_mode: CaptureHiddenMode = None
 
+    # For padding
+    padded_static_len: int = -1  # -1 if not padded
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
@@ -203,8 +223,13 @@ class ForwardBatch:
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
-
         device = model_runner.device
+        extend_input_logprob_token_ids_gpu = None
+        if batch.extend_input_logprob_token_ids is not None:
+            extend_input_logprob_token_ids_gpu = (
+                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
+            )
+
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -220,7 +245,7 @@ class ForwardBatch:
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
-            global_num_tokens=batch.global_num_tokens,
+            token_ids_logprobs=batch.token_ids_logprobs,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
@@ -231,10 +256,12 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )
 
-        if ret.global_num_tokens is not None:
-            max_len = max(ret.global_num_tokens)
+        if batch.global_num_tokens is not None:
+            ret.global_num_tokens_cpu = batch.global_num_tokens
+            max_len = max(ret.global_num_tokens_cpu)
             ret.gathered_buffer = torch.zeros(
                 (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
@@ -256,6 +283,8 @@ class ForwardBatch:
         if ret.forward_mode.is_decode():
             if ret.positions is None:
                 ret.positions = clamp_position(batch.seq_lens)
+            if ret.decode_seq_lens_cpu is None:
+                ret.decode_seq_lens_cpu = batch.decode_seq_lens
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
@@ -263,13 +292,12 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            if (
-                model_runner.server_args.attention_backend != "torch_native"
-                and model_runner.server_args.speculative_algorithm != "NEXTN"
-            ):
+            if model_runner.server_args.attention_backend != "torch_native":
                 ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
-                    ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+                    ret.extend_prefix_lens,
+                    ret.extend_seq_lens,
+                    ret.extend_num_tokens,
                 )
             else:
                 positions, ret.extend_start_loc = compute_position_torch(
@@ -341,6 +369,7 @@ class ForwardBatch:
                 )
             batch.image_inputs[i].mrope_position_delta = mrope_position_delta
             mrope_positions_list[i] = mrope_positions
+
         self.mrope_positions = torch.concat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
             axis=1,
@@ -353,6 +382,8 @@ def compute_position_triton(
 ):
     """Compute positions. It is a fused version of `compute_position_torch`."""
     batch_size = extend_seq_lens.shape[0]
+    has_prefix = extend_prefix_lens.shape[0] == batch_size
+
     positions = torch.empty(
         extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
     )
@@ -366,6 +397,7 @@
         extend_start_loc,
         extend_prefix_lens,
         extend_seq_lens,
+        has_prefix,
     )
 
     return positions, extend_start_loc
@@ -377,11 +409,12 @@ def compute_position_kernel(
     extend_start_loc,
     extend_prefix_lens,
     extend_seq_lens,
+    has_prefix: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
-    prefix_len = tl.load(extend_prefix_lens + pid)
+    prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)
 
     # TODO: optimize this?
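
Note on the `has_prefix` change above: when `extend_prefix_lens` does not have one entry per request, the kernel now falls back to a prefix length of 0 instead of loading from the prefix tensor. Below is a minimal, unfused sketch of the position/start-offset computation, assuming the semantics shown in this hunk; it is an illustration, not sglang code.

import torch

def compute_position_reference(extend_prefix_lens, extend_seq_lens):
    # Mirrors the fused kernel: token j of request i gets position prefix_len_i + j,
    # and the start offset of each request is the running sum of earlier seq lens.
    has_prefix = extend_prefix_lens.shape[0] == extend_seq_lens.shape[0]
    positions, start_locs, offset = [], [], 0
    for i, seq_len in enumerate(extend_seq_lens.tolist()):
        prefix_len = int(extend_prefix_lens[i]) if has_prefix else 0
        positions.append(torch.arange(prefix_len, prefix_len + seq_len))
        start_locs.append(offset)
        offset += seq_len
    return torch.cat(positions), torch.tensor(start_locs)

# Example: two requests, only the first has a cached prefix of length 4.
print(compute_position_reference(torch.tensor([4, 0]), torch.tensor([3, 2])))
# (tensor([4, 5, 6, 0, 1]), tensor([0, 3]))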
sglang/srt/model_executor/model_runner.py

@@ -13,11 +13,14 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
+import datetime
 import gc
 import json
 import logging
+import os
 import time
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -34,6 +37,7 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
@@ -51,14 +55,18 @@ from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
@@ -69,10 +77,15 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
+from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
 
+SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
@@ -86,6 +99,8 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -103,6 +118,8 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
         # Model-specific adjustment
         if (
@@ -113,9 +130,9 @@ class ModelRunner:
             if self.server_args.device != "cpu":
                 if server_args.enable_flashinfer_mla:
                     logger.info(
-                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                        "MLA optimization is turned on. Use flashinfer mla backend."
                     )
-                    self.server_args.attention_backend = "flashinfer"
+                    self.server_args.attention_backend = "flashinfer_mla"
                 else:
                     logger.info("MLA optimization is turned on. Use triton backend.")
                     self.server_args.attention_backend = "triton"
@@ -176,8 +193,13 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "disable_radix_cache": server_args.disable_radix_cache,
+                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
             }
         )
 
@@ -194,6 +216,18 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
+        # Handle the case where some of models don't finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
@@ -228,19 +262,18 @@ class ModelRunner:
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-
         torch.get_device_module(self.device).set_device(self.gpu_id)
+
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
-            # TODO(liangan1): Just use gloo to bypass the initilization fail
-            # Need to use xccl for xpu backend in the future
-            backend = "gloo"
+            backend = "xccl"
         elif self.device == "hpu":
             backend = "hccl"
         elif self.device == "cpu":
             backend = "gloo"
 
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
             monkey_patch_p2p_access_check()
 
@@ -258,6 +291,7 @@ class ModelRunner:
                 rank=self.tp_rank,
                 local_rank=self.gpu_id,
                 distributed_init_method=dist_init_method,
+                timeout=self.server_args.dist_timeout,
             )
             initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
             initialize_dp_attention(
@@ -270,20 +304,24 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
 
+        logger.info(
+            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
+        )
         return min_per_gpu_memory
 
     def load_model(self):
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
@@ -353,11 +391,13 @@ class ModelRunner:
         )
         self.dtype = self.model_config.dtype
 
+        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"avail mem={after_avail_memory:.2f} GB, "
+            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )
 
     def update_weights_from_disk(
@@ -512,8 +552,21 @@ class ModelRunner:
             logger.error(error_msg)
             return False, error_msg
 
-    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
-        self.model.load_weights(named_tensors)
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
+        load_format: Optional[str] = None,
+    ):
+        named_tensors = [
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            for name, tensor in named_tensors
+        ]
+        if load_format == "direct":
+            _model_load_weights_direct(self.model, named_tensors)
+        elif load_format is None:
+            self.model.load_weights(named_tensors)
+        else:
+            raise NotImplementedError(f"Unknown load_format={load_format}")
         return True, "Success"
 
     def get_weights_by_name(
@@ -606,15 +659,31 @@ class ModelRunner:
                 4096,
             )
 
+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
+
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                max_num_reqs = self.server_args.max_num_reqs
             else:
+                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
+                # can be concurrently allocated, so we should give a headroom for it.
                 self.server_args.draft_runner_cache_size = (
                     self.max_total_num_tokens
-                    + max_num_reqs * self.server_args.speculative_num_steps
+                    # draft
+                    + max_num_reqs
+                    * self.server_args.speculative_num_steps
+                    * self.server_args.speculative_eagle_topk
+                    # verify
+                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
+                    # buffer
                     + 100
                 )
+                # Target worker and draft worker shares the same indices for the
+                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                self.server_args.max_num_reqs = max_num_reqs
 
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
630
699
  "Not enough memory. Please try to increase --mem-fraction-static."
631
700
  )
632
701
 
633
- self.req_to_token_pool = ReqToTokenPool(
634
- size=max_num_reqs + 1,
635
- max_context_len=self.model_config.context_len + 4,
636
- device=self.device,
637
- enable_memory_saver=self.server_args.enable_memory_saver,
638
- )
702
+ if self.req_to_token_pool is None:
703
+ self.req_to_token_pool = ReqToTokenPool(
704
+ size=max_num_reqs + 1,
705
+ max_context_len=self.model_config.context_len + 4,
706
+ device=self.device,
707
+ enable_memory_saver=self.server_args.enable_memory_saver,
708
+ )
709
+ else:
710
+ # Draft worker shares req_to_token_pool with the target worker.
711
+ assert self.is_draft_worker
712
+
713
+ if self.token_to_kv_pool_allocator is None:
714
+ self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
715
+ self.max_total_num_tokens,
716
+ dtype=self.kv_cache_dtype,
717
+ device=self.device,
718
+ )
719
+ else:
720
+ assert self.is_draft_worker
721
+
639
722
  if (
640
723
  self.model_config.attention_arch == AttentionArch.MLA
641
724
  and not self.server_args.disable_mla
@@ -703,6 +786,8 @@ class ModelRunner:
703
786
  self.attn_backend = TritonAttnBackend(self)
704
787
  elif self.server_args.attention_backend == "torch_native":
705
788
  self.attn_backend = TorchNativeAttnBackend(self)
789
+ elif self.server_args.attention_backend == "flashinfer_mla":
790
+ self.attn_backend = FlashInferMLAAttnBackend(self)
706
791
  else:
707
792
  raise ValueError(
708
793
  f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -737,9 +822,16 @@ class ModelRunner:
737
822
  return
738
823
 
739
824
  tic = time.time()
740
- logger.info("Capture cuda graph begin. This can take up to several minutes.")
825
+ before_mem = get_available_gpu_memory(self.device, self.gpu_id)
826
+ logger.info(
827
+ f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
828
+ )
741
829
  self.cuda_graph_runner = CudaGraphRunner(self)
742
- logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
830
+ after_mem = get_available_gpu_memory(self.device, self.gpu_id)
831
+ logger.info(
832
+ f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
833
+ f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
834
+ )
743
835
 
744
836
  def apply_torch_tp(self):
745
837
  logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -754,8 +846,12 @@ class ModelRunner:
754
846
  forward_batch.input_ids, forward_batch.positions, forward_batch
755
847
  )
756
848
 
757
- def forward_extend(self, forward_batch: ForwardBatch):
758
- self.attn_backend.init_forward_metadata(forward_batch)
849
+ def forward_extend(
850
+ self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
851
+ ):
852
+ if not skip_attn_backend_init:
853
+ self.attn_backend.init_forward_metadata(forward_batch)
854
+
759
855
  if self.is_generation:
760
856
  if forward_batch.input_embeds is None:
761
857
  return self.model.forward(
@@ -799,11 +895,10 @@ class ModelRunner:
799
895
  else:
800
896
  raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
801
897
 
802
- def sample(
803
- self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
804
- ) -> torch.Tensor:
898
+ def _preprocess_logits(
899
+ self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
900
+ ):
805
901
  # Apply logit bias
806
- sampling_info = forward_batch.sampling_info
807
902
  if sampling_info.sampling_info_done:
808
903
  # Overlap mode: the function update_regex_vocab_mask was executed
809
904
  # in process_batch_result of the last batch.
@@ -812,15 +907,77 @@ class ModelRunner:
812
907
  else:
813
908
  # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
814
909
  sampling_info.update_regex_vocab_mask()
815
- sampling_info.update_penalties()
816
910
  sampling_info.apply_logits_bias(logits_output.next_token_logits)
817
911
 
912
+ def update_output_logprobs(
913
+ self,
914
+ logits_output: LogitsProcessorOutput,
915
+ sampling_info: SamplingBatchInfo,
916
+ top_logprobs_nums: List[int],
917
+ token_ids_logprobs: List[int],
918
+ next_token_ids: torch.Tensor,
919
+ *,
920
+ num_tokens_per_req: List[int],
921
+ ):
922
+ """Update the logits_output's output logprob based on next_token_ids
923
+
924
+ Args:
925
+ logits_output: The logits output from the model forward
926
+ sampling_info: Sampling info for logprob calculation
927
+ top_logprobs_nums: Number of logprobs per request.
928
+ next_token_ids: Next token ids.
929
+ num_tokens_per_req: The number of tokens per request.
930
+
931
+ Returns:
932
+ A list of next_token_ids
933
+ """
934
+ self._preprocess_logits(logits_output, sampling_info)
935
+ # We should repeat top_logprobs_nums to match num_tokens_per_req.
936
+ top_logprobs_nums_repeat_interleaved = []
937
+ token_ids_logprobs_repeat_interleaved = []
938
+ for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
939
+ top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
940
+ for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
941
+ token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
942
+ self.sampler(
943
+ logits_output,
944
+ sampling_info,
945
+ True,
946
+ top_logprobs_nums_repeat_interleaved,
947
+ token_ids_logprobs_repeat_interleaved,
948
+ batch_next_token_ids=next_token_ids,
949
+ )
950
+
951
+ def sample(
952
+ self,
953
+ logits_output: LogitsProcessorOutput,
954
+ forward_batch: ForwardBatch,
955
+ ) -> torch.Tensor:
956
+ """Sample and compute logprobs and update logits_output.
957
+
958
+ Args:
959
+ logits_output: The logits output from the model forward
960
+ forward_batch: The forward batch that generates logits_output
961
+
962
+ Returns:
963
+ A list of next_token_ids
964
+ """
965
+ # For duplex models with multiple output streams.
966
+ if isinstance(logits_output, tuple):
967
+ return torch.stack(
968
+ [self.sample(values, forward_batch) for values in logits_output],
969
+ axis=-1,
970
+ )
971
+
972
+ self._preprocess_logits(logits_output, forward_batch.sampling_info)
973
+
818
974
  # Sample the next tokens
819
975
  next_token_ids = self.sampler(
820
976
  logits_output,
821
- sampling_info,
977
+ forward_batch.sampling_info,
822
978
  forward_batch.return_logprob,
823
979
  forward_batch.top_logprobs_nums,
980
+ forward_batch.token_ids_logprobs,
824
981
  )
825
982
  return next_token_ids
826
983
 
@@ -832,3 +989,26 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
+
+
+def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
+    params_dict = dict(model.named_parameters())
+    for name, tensor in named_tensors:
+        default_weight_loader(params_dict[name], tensor)
+
+
+def _unwrap_tensor(tensor, tp_rank):
+    if isinstance(tensor, LocalSerializedTensor):
+        return tensor.get(tp_rank)
+    return tensor
+
+
+@dataclass
+class LocalSerializedTensor:
+    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
+    The i-th element in the list corresponds to i-th rank's GPU."""
+
+    values: List[bytes]
+
+    def get(self, rank: int):
+        return MultiprocessingSerializer.deserialize(self.values[rank])
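
A minimal sketch of how the new `LocalSerializedTensor` wrapper is resolved per rank. To keep it self-contained, `pickle` stands in for sglang's `MultiprocessingSerializer` and the class is re-declared locally with a different name; only `LocalSerializedTensor`, `_unwrap_tensor`, and `update_weights_from_tensor` are real names from the hunk above, and the parameter name is arbitrary.

import pickle
from dataclasses import dataclass
from typing import List

import torch

@dataclass
class LocalSerializedTensorSketch:
    values: List[bytes]  # one serialized payload per TP rank

    def get(self, rank: int) -> torch.Tensor:
        return pickle.loads(self.values[rank])

# Each rank serializes its own shard; rank i later recovers only values[i],
# which is what _unwrap_tensor does before handing tensors to model.load_weights.
shards = [torch.randn(4) for _ in range(2)]
wrapped = LocalSerializedTensorSketch(values=[pickle.dumps(t) for t in shards])
named_tensors = [("lm_head.weight", wrapped)]  # hypothetical parameter name
print(wrapped.get(0))  # the tensor destined for TP rank 0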
sglang/srt/model_loader/loader.py

@@ -11,7 +11,7 @@ import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, cast
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
 import gguf
 import huggingface_hub
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.device_config import DeviceConfig
@@ -197,7 +197,7 @@ class DefaultModelLoader(BaseModelLoader):
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             # pylint: disable=C.
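
The last hunk tightens the ModelScope switch: previously, setting SGLANG_USE_MODELSCOPE to any value (even "false" or an empty string) routed downloads through ModelScope; now only the exact string "True" does. A quick illustration of the behavioral difference:

import os

os.environ["SGLANG_USE_MODELSCOPE"] = "false"
old_check = "SGLANG_USE_MODELSCOPE" in os.environ                    # True  -> ModelScope path taken
new_check = os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True"  # False -> ModelScope path skipped
print(old_check, new_check)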