sglang 0.4.3.post4__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/hf_transformers_utils.py +16 -1
  14. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  15. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  16. sglang/srt/layers/attention/triton_backend.py +1 -3
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  18. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  19. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  20. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  21. sglang/srt/layers/attention/vision.py +43 -62
  22. sglang/srt/layers/linear.py +1 -1
  23. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  24. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  33. sglang/srt/layers/parameter.py +10 -0
  34. sglang/srt/layers/quantization/__init__.py +90 -68
  35. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  36. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/fp8.py +174 -106
  63. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  64. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  65. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  66. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  67. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  68. sglang/srt/layers/rotary_embedding.py +5 -3
  69. sglang/srt/layers/sampler.py +29 -35
  70. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  71. sglang/srt/lora/backend/__init__.py +9 -12
  72. sglang/srt/managers/cache_controller.py +72 -8
  73. sglang/srt/managers/image_processor.py +37 -631
  74. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  75. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  76. sglang/srt/managers/image_processors/llava.py +152 -0
  77. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  78. sglang/srt/managers/image_processors/mlama.py +60 -0
  79. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  80. sglang/srt/managers/io_struct.py +32 -15
  81. sglang/srt/managers/multi_modality_padding.py +134 -0
  82. sglang/srt/managers/schedule_batch.py +212 -117
  83. sglang/srt/managers/schedule_policy.py +40 -8
  84. sglang/srt/managers/scheduler.py +124 -665
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
  86. sglang/srt/managers/tokenizer_manager.py +6 -6
  87. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  88. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  89. sglang/srt/mem_cache/chunk_cache.py +12 -44
  90. sglang/srt/mem_cache/hiradix_cache.py +63 -34
  91. sglang/srt/mem_cache/memory_pool.py +78 -17
  92. sglang/srt/mem_cache/paged_allocator.py +283 -0
  93. sglang/srt/mem_cache/radix_cache.py +117 -36
  94. sglang/srt/model_executor/cuda_graph_runner.py +9 -4
  95. sglang/srt/model_executor/forward_batch_info.py +12 -8
  96. sglang/srt/model_executor/model_runner.py +63 -63
  97. sglang/srt/model_loader/loader.py +2 -1
  98. sglang/srt/model_loader/weight_utils.py +1 -1
  99. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  100. sglang/srt/models/deepseek_nextn.py +23 -3
  101. sglang/srt/models/deepseek_v2.py +25 -19
  102. sglang/srt/models/minicpmv.py +28 -89
  103. sglang/srt/models/mllama.py +1 -1
  104. sglang/srt/models/qwen2.py +0 -1
  105. sglang/srt/models/qwen2_5_vl.py +25 -50
  106. sglang/srt/models/qwen2_vl.py +33 -49
  107. sglang/srt/openai_api/adapter.py +37 -15
  108. sglang/srt/openai_api/protocol.py +8 -1
  109. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  110. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  111. sglang/srt/server_args.py +19 -11
  112. sglang/srt/speculative/eagle_worker.py +75 -39
  113. sglang/srt/utils.py +104 -9
  114. sglang/test/runners.py +104 -10
  115. sglang/test/test_block_fp8.py +106 -16
  116. sglang/test/test_custom_ops.py +88 -0
  117. sglang/test/test_utils.py +20 -4
  118. sglang/utils.py +0 -4
  119. sglang/version.py +1 -1
  120. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -10
  121. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/RECORD +124 -79
  122. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
  123. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
  124. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/forward_batch_info.py

@@ -38,12 +38,12 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import get_compiler_backend, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
-    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+    from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -51,9 +51,8 @@ if TYPE_CHECKING:


 class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
     # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
+    # It is also called "prefill" in common terminology.
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
@@ -153,6 +152,12 @@ class ForwardBatch:
     top_logprobs_nums: Optional[List[int]] = None
     token_ids_logprobs: Optional[List[List[int]]] = None

+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    temperature: torch.Tensor = None
+    top_p_normalized_logprobs: bool = False
+    top_p: torch.Tensor = None
+
     # Position information
     positions: torch.Tensor = None

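Note: the four new ForwardBatch fields carry per-request sampling parameters into logprob post-processing, so returned logprobs can reflect the same temperature scaling and top-p truncation the sampler used. A minimal sketch of the two transforms (illustrative only; the diff adds the fields, and the function name below is hypothetical):

    import torch

    def postprocess_logprobs(
        logits: torch.Tensor,       # (batch, vocab)
        temperature: torch.Tensor,  # (batch, 1), used when temp_scaled_logprobs
        top_p: torch.Tensor,        # (batch, 1), used when top_p_normalized_logprobs
    ) -> torch.Tensor:
        # Temperature-scaled logprobs: rescale logits before log-softmax.
        logprobs = torch.log_softmax(logits / temperature, dim=-1)
        # Top-p-normalized logprobs: renormalize over the nucleus only.
        probs = logprobs.exp()
        sorted_probs, idx = probs.sort(dim=-1, descending=True)
        cum = sorted_probs.cumsum(dim=-1)
        in_nucleus = (cum - sorted_probs) < top_p  # mass before this token
        kept = torch.where(in_nucleus, sorted_probs, torch.zeros_like(sorted_probs))
        renorm = (kept / kept.sum(dim=-1, keepdim=True)).log()
        return torch.full_like(probs, float("-inf")).scatter(-1, idx, renorm)
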
@@ -189,7 +194,7 @@ class ForwardBatch:

     # Attention backend
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool: BaseTokenToKVPool = None
+    token_to_kv_pool: KVCache = None
     attn_backend: AttentionBackend = None

     # For DP attention
@@ -229,7 +234,6 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu = (
                 batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
             )
-
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -417,8 +421,8 @@ def compute_position_kernel(
     prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)

-    # TODO: optimize this?
-    cumsum_start = 0
+    # NOTE: This can be slow for large bs
+    cumsum_start = tl.cast(0, tl.int64)
     for i in range(pid):
         cumsum_start += tl.load(extend_seq_lens + i)

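Note: casting the accumulator to tl.int64 keeps the flat write offset from overflowing int32 once the summed extend lengths grow large. What the kernel computes, as a plain-Python sketch (names follow the kernel arguments; this is not the Triton code itself):

    def compute_positions(extend_prefix_lens, extend_seq_lens):
        # Request i's positions run from prefix_len to prefix_len + seq_len - 1
        # and are written starting at flat offset sum(extend_seq_lens[:i]).
        positions = []
        for prefix_len, seq_len in zip(extend_prefix_lens, extend_seq_lens):
            positions.extend(range(prefix_len, prefix_len + seq_len))
        return positions
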

sglang/srt/model_executor/model_runner.py

@@ -35,17 +35,13 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
-from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -57,9 +53,16 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.loader import (
+    DefaultModelLoader,
+    device_loading_context,
+    get_model_loader,
+)
+from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
@@ -77,11 +80,9 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
-from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

-
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

@@ -118,6 +119,7 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

@@ -160,6 +162,11 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()

+        # If it is a draft model tp_group can be different.
+        self.initialize(min_per_gpu_memory)
+
+    def initialize(self, min_per_gpu_memory: float):
+        server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -299,15 +306,16 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
-        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
-                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                 )

         logger.info(
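
Note: the check compares the minimum free memory across all ranks against 90% of the local rank's free memory. A worked illustration with hypothetical numbers: if the slowest rank reports 30 GB free while the local rank has 40 GB free, then 30 < 40 * 0.9 = 36 and startup aborts; the appended f-string makes the error print min_per_gpu_memory=30.0, local_gpu_memory=40.0, local_gpu_memory * 0.9=36.0 instead of leaving the operator to guess which GPU is busy.
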
@@ -347,6 +355,8 @@ class ModelRunner:
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
+        monkey_patch_isinstance_for_vllm_base_layer()
+
         with self.memory_saver_adapter.region():
             self.model = get_model(
                 model_config=self.model_config,
@@ -354,6 +364,7 @@ class ModelRunner:
                 device_config=DeviceConfig(self.device),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
+        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
@@ -411,13 +422,6 @@ class ModelRunner:
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
         """Update engine weights in-place from the disk."""
-        from sglang.srt.model_loader.loader import (
-            DefaultModelLoader,
-            device_loading_context,
-            get_model_loader,
-        )
-        from sglang.srt.model_loader.utils import set_default_torch_dtype
-
         logger.info(
             f"Update engine weights online from disk begin. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -427,7 +431,7 @@ class ModelRunner:
         self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)

-        # Only support vllm DefaultModelLoader for now
+        # Only support DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
@@ -701,6 +705,12 @@ class ModelRunner:
         )
         self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

+        self.max_total_num_tokens = (
+            self.max_total_num_tokens
+            // self.server_args.page_size
+            * self.server_args.page_size
+        )
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
@@ -723,6 +733,7 @@ class ModelRunner:
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
@@ -733,6 +744,7 @@ class ModelRunner:
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
@@ -744,6 +756,7 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
@@ -753,12 +766,21 @@ class ModelRunner:
         )

         if self.token_to_kv_pool_allocator is None:
-            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                self.max_total_num_tokens,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-                kvcache=self.token_to_kv_pool,
-            )
+            if self.page_size == 1:
+                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
+            else:
+                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
         else:
             assert self.is_draft_worker

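Note: the budget rounding added earlier (max_total_num_tokens // page_size * page_size) guarantees the allocator only ever manages whole pages; page_size == 1 keeps the existing per-token allocator, while larger page sizes route to the new PagedTokenToKVPoolAllocator. A toy free-list sketch of both ideas (not the actual implementation in sglang/srt/mem_cache/paged_allocator.py):

    class ToyPagedAllocator:
        def __init__(self, max_total_num_tokens: int, page_size: int):
            # Floor the budget to a page multiple: 1000 tokens at page_size=16
            # becomes 992 (62 full pages); page_size=1 keeps all 1000.
            max_total_num_tokens = max_total_num_tokens // page_size * page_size
            self.page_size = page_size
            self.free_pages = list(range(max_total_num_tokens // page_size))

        def alloc(self, num_tokens: int) -> list[int]:
            # Round the request up to whole pages, then expand to token slots.
            num_pages = -(-num_tokens // self.page_size)  # ceil division
            if num_pages > len(self.free_pages):
                raise MemoryError("out of KV cache pages")
            pages = [self.free_pages.pop() for _ in range(num_pages)]
            slots = [p * self.page_size + i
                     for p in pages for i in range(self.page_size)]
            return slots[:num_tokens]
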
@@ -779,10 +801,13 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
             # Init streams
             if self.server_args.speculative_algorithm == "EAGLE":
                 self.plan_stream_for_flashinfer = torch.cuda.Stream()
-
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
@@ -794,12 +819,26 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             if self.server_args.enable_double_sparsity:
+                from sglang.srt.layers.attention.double_sparsity_backend import (
+                    DoubleSparseAttnBackend,
+                )
+
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
+                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
                 self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
+            from sglang.srt.layers.attention.torch_native_backend import (
+                TorchNativeAttnBackend,
+            )
+
             self.attn_backend = TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
             self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
@@ -928,45 +967,6 @@ class ModelRunner:
         sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)

-    def update_output_logprobs(
-        self,
-        logits_output: LogitsProcessorOutput,
-        sampling_info: SamplingBatchInfo,
-        top_logprobs_nums: List[int],
-        token_ids_logprobs: List[int],
-        next_token_ids: torch.Tensor,
-        *,
-        num_tokens_per_req: List[int],
-    ):
-        """Update the logits_output's output logprob based on next_token_ids
-
-        Args:
-            logits_output: The logits output from the model forward
-            sampling_info: Sampling info for logprob calculation
-            top_logprobs_nums: Number of logprobs per request.
-            next_token_ids: Next token ids.
-            num_tokens_per_req: The number of tokens per request.
-
-        Returns:
-            A list of next_token_ids
-        """
-        self._preprocess_logits(logits_output, sampling_info)
-        # We should repeat top_logprobs_nums to match num_tokens_per_req.
-        top_logprobs_nums_repeat_interleaved = []
-        token_ids_logprobs_repeat_interleaved = []
-        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
-            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
-        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
-            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
-        self.sampler(
-            logits_output,
-            sampling_info,
-            True,
-            top_logprobs_nums_repeat_interleaved,
-            token_ids_logprobs_repeat_interleaved,
-            batch_next_token_ids=next_token_ids,
-        )
-
     def sample(
         self,
         logits_output: LogitsProcessorOutput,

sglang/srt/model_loader/loader.py

@@ -48,6 +48,7 @@ from sglang.srt.model_loader.weight_utils import (
     safetensors_weights_iterator,
 )
 from sglang.srt.utils import (
+    get_bool_env_var,
     get_device_capability,
     is_pin_memory_available,
     set_weight_attrs,
@@ -197,7 +198,7 @@ class DefaultModelLoader(BaseModelLoader):

         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
+        if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
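
Note: the old comparison only matched the exact string "True", so settings such as SGLANG_USE_MODELSCOPE=true or =1 were silently ignored. The helper presumably normalizes common truthy spellings; a plausible sketch (the actual sglang.srt.utils implementation may differ):

    import os

    def get_bool_env_var(name: str, default: str = "false") -> bool:
        # Accept "true"/"True"/"1" rather than only the exact string "True".
        return os.getenv(name, default).strip().lower() in ("true", "1")
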

sglang/srt/model_loader/weight_utils.py

@@ -455,7 +455,7 @@ def pt_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        state = torch.load(bin_file, map_location="cpu")
+        state = torch.load(bin_file, map_location="cpu", weights_only=True)
        yield from state.items()
        del state
        torch.cuda.empty_cache()
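
Note: torch.load on an untrusted .bin file can execute arbitrary code through pickle; weights_only=True (available since PyTorch 1.13) restricts deserialization to tensors and other allow-listed types. The call site is otherwise unchanged:

    import torch

    # Rejects pickled payloads that are not plain tensors/containers,
    # mitigating code execution from untrusted checkpoints.
    state = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)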