sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -33,9 +33,9 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_available_gpu_memory, is_hip
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -119,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     else:
         capture_bs = list(range(1, 33))
 
-    if is_hip_:
+    if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
 
@@ -236,7 +237,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.dp_size,
+                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
@@ -264,21 +265,24 @@
     def model_capture_mode(self):
         if hasattr(self.model_runner.model, "capture_mode"):
             self.model_runner.model.capture_mode = True
+        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
+            self.model_runner.token_to_kv_pool.capture_mode = True
 
         yield
 
         if hasattr(self.model_runner.model, "capture_mode"):
             self.model_runner.model.capture_mode = False
+        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
+            self.model_runner.token_to_kv_pool.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(
-                forward_batch.global_num_tokens_cpu
-            ), max(forward_batch.global_num_tokens_cpu)
+            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                total_global_tokens in self.graphs
                 if self.disable_padding
-                else max_num_tokens <= self.max_bs
+                else total_global_tokens <= self.max_bs
             )
         else:
             is_bs_supported = (
@@ -300,12 +304,26 @@
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
+            avail_mem = get_available_gpu_memory(
+                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
+            )
+            # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
-                tqdm.tqdm(self.capture_bs)
+                tqdm.tqdm(list(reversed(self.capture_bs)))
                 if get_tensor_model_parallel_rank() == 0
-                else self.capture_bs
+                else reversed(self.capture_bs)
             )
             for bs in capture_range:
+                if get_tensor_model_parallel_rank() == 0:
+                    avail_mem = get_available_gpu_memory(
+                        self.model_runner.device,
+                        self.model_runner.gpu_id,
+                        empty_cache=False,
+                    )
+                    capture_range.set_description(
+                        f"Capturing batches ({avail_mem=:.2f} GB)"
+                    )
+
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -340,8 +358,18 @@
         mrope_positions = self.mrope_positions[:, :bs]
 
         if self.enable_dp_attention:
-            global_num_tokens = [bs] * self.tp_size
-            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -366,7 +394,7 @@
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens_cpu=global_num_tokens,
+            global_num_tokens_gpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -387,6 +415,9 @@
 
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
             return logits_output.next_token_logits, logits_output.hidden_states
 
@@ -421,7 +452,7 @@
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay_prepare(self, forward_batch: ForwardBatch):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
@@ -430,7 +461,7 @@
         # Pad
        if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
+                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -454,6 +485,8 @@
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
 
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
@@ -470,14 +503,29 @@
             seq_lens_cpu=self.seq_lens_cpu,
         )
 
+        # Store fields
+        self.raw_bs = raw_bs
+        self.raw_num_token = raw_num_token
+        self.bs = bs
+
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
         # Replay
-        self.graphs[bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[bs]
+        self.graphs[self.bs].replay()
+        next_token_logits, hidden_states = self.output_buffers[self.bs]
 
         logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[:raw_num_token],
+            next_token_logits=next_token_logits[: self.raw_num_token],
             hidden_states=(
-                hidden_states[:raw_num_token] if hidden_states is not None else None
+                hidden_states[: self.raw_num_token]
+                if hidden_states is not None
+                else None
            ),
         )
         return logits_output
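
Note: with DP attention, the graph to replay is now selected from the total number of tokens across all DP ranks rather than the per-rank maximum, matching the gathered_buffer that is sized by the token sum. A minimal, self-contained sketch of the new bucket selection (the batch sizes and token counts below are invented for illustration):

import bisect

capture_bs = [1, 2, 4, 8, 16, 24, 32]     # batch sizes captured as CUDA graphs
global_num_tokens_cpu = [3, 5, 2, 4]      # decode tokens per DP rank in one step

total_global_tokens = sum(global_num_tokens_cpu)              # 14
index = bisect.bisect_left(capture_bs, total_global_tokens)   # first bucket >= 14
padded_bs = capture_bs[index]                                 # 16

print(f"replay() will use the graph captured for bs={padded_bs}")

Previously the bucket was chosen from max(global_num_tokens_cpu), and with padding disabled a graph could only be reused when every rank contributed the same number of tokens.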
sglang/srt/model_executor/forward_batch_info.py

@@ -43,7 +43,7 @@ from sglang.srt.utils import get_compiler_backend
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
-    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+    from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -51,9 +51,8 @@ if TYPE_CHECKING:
 
 
 class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
     # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
+    # It is also called "prefill" in common terminology.
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
@@ -153,6 +152,12 @@ class ForwardBatch:
     top_logprobs_nums: Optional[List[int]] = None
     token_ids_logprobs: Optional[List[List[int]]] = None
 
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    temperature: torch.Tensor = None
+    top_p_normalized_logprobs: bool = False
+    top_p: torch.Tensor = None
+
     # Position information
     positions: torch.Tensor = None
 
@@ -189,7 +194,7 @@
 
     # Attention backend
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool: BaseTokenToKVPool = None
+    token_to_kv_pool: KVCache = None
     attn_backend: AttentionBackend = None
 
     # For DP attention
@@ -229,7 +234,6 @@
             extend_input_logprob_token_ids_gpu = (
                 batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
             )
-
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -259,15 +263,24 @@
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )
 
+        # For DP attention
         if batch.global_num_tokens is not None:
             ret.global_num_tokens_cpu = batch.global_num_tokens
-            max_len = max(ret.global_num_tokens_cpu)
+            ret.global_num_tokens_gpu = torch.tensor(
+                batch.global_num_tokens, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_gpu = torch.tensor(
+                batch.global_num_tokens_for_logprob, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            sum_len = sum(batch.global_num_tokens)
             ret.gathered_buffer = torch.zeros(
-                (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
+                (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
                 device=device,
             )
-
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
             return ret
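
Note: the DP-attention gather buffer in ForwardBatch.init_new is now allocated with sum(global_num_tokens) rows instead of max(global_num_tokens) * tp_size, so its shape matches exactly the tokens that are gathered. A small arithmetic sketch with invented per-rank counts:

# Hypothetical per-rank token counts for one DP step (illustrative values only).
global_num_tokens = [7, 5, 6, 4]
tp_size = 4
hidden_size = 4096

old_rows = max(global_num_tokens) * tp_size   # 7 * 4 = 28 rows
new_rows = sum(global_num_tokens)             # 7 + 5 + 6 + 4 = 22 rows

print(f"gathered_buffer: {old_rows} x {hidden_size} -> {new_rows} x {hidden_size}")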
@@ -417,8 +430,8 @@ def compute_position_kernel(
     prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)
 
-    # TODO: optimize this?
-    cumsum_start = 0
+    # NOTE: This can be slow for large bs
+    cumsum_start = tl.cast(0, tl.int64)
     for i in range(pid):
         cumsum_start += tl.load(extend_seq_lens + i)
 
sglang/srt/model_executor/model_runner.py

@@ -35,17 +35,13 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
-from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -57,9 +53,16 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.loader import (
+    DefaultModelLoader,
+    device_loading_context,
+    get_model_loader,
+)
+from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
@@ -77,11 +80,9 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
-from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 
@@ -118,6 +119,7 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
@@ -160,6 +162,11 @@
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        # If it is a draft model tp_group can be different.
+        self.initialize(min_per_gpu_memory)
+
+    def initialize(self, min_per_gpu_memory: float):
+        server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -299,15 +306,16 @@
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
-        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
-                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                 )
 
         logger.info(
@@ -347,6 +355,8 @@
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
+        monkey_patch_isinstance_for_vllm_base_layer()
+
         with self.memory_saver_adapter.region():
             self.model = get_model(
                 model_config=self.model_config,
@@ -354,6 +364,7 @@
                 device_config=DeviceConfig(self.device),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
+        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
@@ -411,13 +422,6 @@
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
         """Update engine weights in-place from the disk."""
-        from sglang.srt.model_loader.loader import (
-            DefaultModelLoader,
-            device_loading_context,
-            get_model_loader,
-        )
-        from sglang.srt.model_loader.utils import set_default_torch_dtype
-
         logger.info(
             f"Update engine weights online from disk begin. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -427,7 +431,7 @@
         self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)
 
-        # Only support vllm DefaultModelLoader for now
+        # Only support DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
@@ -701,6 +705,12 @@
                 )
             self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
 
+        self.max_total_num_tokens = (
+            self.max_total_num_tokens
+            // self.server_args.page_size
+            * self.server_args.page_size
+        )
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
@@ -723,6 +733,7 @@
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
@@ -733,6 +744,7 @@
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
@@ -744,6 +756,7 @@
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
@@ -753,12 +766,21 @@
             )
 
         if self.token_to_kv_pool_allocator is None:
-            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                self.max_total_num_tokens,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-                kvcache=self.token_to_kv_pool,
-            )
+            if self.page_size == 1:
+                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
+            else:
+                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
         else:
             assert self.is_draft_worker
 
@@ -779,10 +801,13 @@
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
             # Init streams
             if self.server_args.speculative_algorithm == "EAGLE":
                 self.plan_stream_for_flashinfer = torch.cuda.Stream()
-
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
@@ -794,12 +819,26 @@
                 "Please use `--attention-backend flashinfer`."
             )
             if self.server_args.enable_double_sparsity:
+                from sglang.srt.layers.attention.double_sparsity_backend import (
+                    DoubleSparseAttnBackend,
+                )
+
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
+                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
                 self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
+            from sglang.srt.layers.attention.torch_native_backend import (
+                TorchNativeAttnBackend,
+            )
+
             self.attn_backend = TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
             self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
@@ -928,45 +967,6 @@ class ModelRunner:
         sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)
 
-    def update_output_logprobs(
-        self,
-        logits_output: LogitsProcessorOutput,
-        sampling_info: SamplingBatchInfo,
-        top_logprobs_nums: List[int],
-        token_ids_logprobs: List[int],
-        next_token_ids: torch.Tensor,
-        *,
-        num_tokens_per_req: List[int],
-    ):
-        """Update the logits_output's output logprob based on next_token_ids
-
-        Args:
-            logits_output: The logits output from the model forward
-            sampling_info: Sampling info for logprob calculation
-            top_logprobs_nums: Number of logprobs per request.
-            next_token_ids: Next token ids.
-            num_tokens_per_req: The number of tokens per request.
-
-        Returns:
-            A list of next_token_ids
-        """
-        self._preprocess_logits(logits_output, sampling_info)
-        # We should repeat top_logprobs_nums to match num_tokens_per_req.
-        top_logprobs_nums_repeat_interleaved = []
-        token_ids_logprobs_repeat_interleaved = []
-        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
-            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
-        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
-            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
-        self.sampler(
-            logits_output,
-            sampling_info,
-            True,
-            top_logprobs_nums_repeat_interleaved,
-            token_ids_logprobs_repeat_interleaved,
-            batch_next_token_ids=next_token_ids,
-        )
-
     def sample(
         self,
         logits_output: LogitsProcessorOutput,
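
Note: ModelRunner now records a page_size from --page-size, rounds the KV token budget down to a whole number of pages, and picks a paged allocator whenever a page holds more than one token. A minimal sketch of that arithmetic and branch; the numbers are invented, and the class names come from the diff above:

# Hypothetical values for illustration.
max_total_num_tokens = 163_843
page_size = 16

# Round the budget down to a multiple of the page size, as in init_memory_pool.
max_total_num_tokens = max_total_num_tokens // page_size * page_size  # 163_840

# page_size == 1 keeps the token-granularity allocator; larger pages switch
# to the new PagedTokenToKVPoolAllocator.
allocator = "TokenToKVPoolAllocator" if page_size == 1 else "PagedTokenToKVPoolAllocator"
print(max_total_num_tokens, allocator)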
sglang/srt/model_loader/loader.py

@@ -48,6 +48,7 @@ from sglang.srt.model_loader.weight_utils import (
     safetensors_weights_iterator,
 )
 from sglang.srt.utils import (
+    get_bool_env_var,
     get_device_capability,
     is_pin_memory_available,
     set_weight_attrs,
@@ -197,7 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
+        if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             # pylint: disable=C.
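
Note: the ModelScope switch now goes through get_bool_env_var instead of comparing the environment value against the literal string "True". The sketch below shows what such a helper typically does; the exact accepted spellings in sglang.srt.utils.get_bool_env_var may differ:

import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Treat common truthy spellings as True rather than requiring exactly "True".
    return os.getenv(name, default).strip().lower() in ("1", "true")

# Old behavior: os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True"
# would reject "true", "TRUE", or "1".
print(get_bool_env_var("SGLANG_USE_MODELSCOPE"))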
sglang/srt/model_loader/weight_utils.py

@@ -455,7 +455,7 @@ def pt_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        state = torch.load(bin_file, map_location="cpu")
+        state = torch.load(bin_file, map_location="cpu", weights_only=True)
         yield from state.items()
         del state
         torch.cuda.empty_cache()
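
Note: passing weights_only=True makes torch.load restrict unpickling to tensors and plain containers, so iterating a .bin checkpoint no longer executes arbitrary pickled code. A tiny self-contained demonstration (the file path is a throwaway created by the script itself):

import torch

# Save a toy state dict, then reload it with the restricted unpickler.
torch.save({"w": torch.zeros(2, 2)}, "/tmp/demo_state.bin")
state = torch.load("/tmp/demo_state.bin", map_location="cpu", weights_only=True)
for name, tensor in state.items():
    print(name, tuple(tensor.shape))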