sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import inspect
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
+    PPProxyTensors,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     is_hip,
+    rank0_log,
 )
 
 if TYPE_CHECKING:
@@ -135,7 +138,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
 
     gpu_mem = get_device_memory_capacity()
     # Batch size of each rank will not become so large when DP is on
-    if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
+    if gpu_mem is not None and gpu_mem > 96 * 1024:
         capture_bs += list(range(160, 257, 8))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
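For scale: the old threshold of 81920 (80 × 1024) suggests `get_device_memory_capacity()` reports MB, so the bar for the extra capture sizes moves from roughly 80 GB to 96 GB, and the `dp_size == 1` guard is dropped. A standalone sketch of the new rule (the helper name is hypothetical):

```python
# Hypothetical standalone version of the new condition above; assumes gpu_mem
# is in MB, as the old 81920 (= 80 * 1024) threshold suggests.
def extra_capture_batch_sizes(gpu_mem_mb):
    if gpu_mem_mb is not None and gpu_mem_mb > 96 * 1024:
        return list(range(160, 257, 8))  # 160, 168, ..., 256
    return []

print(extra_capture_batch_sizes(80 * 1024))   # [] -- 80 GB devices no longer qualify
print(extra_capture_batch_sizes(141 * 1024))  # 13 extra sizes on larger devices
```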
@@ -188,10 +191,11 @@ class CudaGraphRunner:
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
-
+        rank0_log(f"Capture cuda graph bs {self.capture_bs}")
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -220,6 +224,9 @@ class CudaGraphRunner:
         if self.enable_torch_compile:
             set_torch_compile_config()
 
+        if self.model_runner.server_args.lora_paths is not None:
+            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
+
         # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -231,6 +238,19 @@ class CudaGraphRunner:
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
+            # pipeline parallelism
+            if self.pp_size > 1:
+                self.pp_proxy_tensors = {
+                    "hidden_states": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                    "residual": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                }
+
             # Speculative_inference
             if (
                 model_runner.spec_algorithm.is_eagle3()
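These proxy buffers are allocated once at the maximum batch size and only ever sliced afterwards, so the storage a captured CUDA graph reads from never moves. A CPU-only sketch of that invariant (sizes illustrative):

```python
import torch

# Persistent proxy buffers, allocated once at max batch size (illustrative sizes).
max_bs, hidden_size = 8, 16
pp_proxy_tensors = {
    "hidden_states": torch.zeros((max_bs, hidden_size), dtype=torch.bfloat16),
    "residual": torch.zeros((max_bs, hidden_size), dtype=torch.bfloat16),
}

# Per-batch slicing (as in the capture hunk below) creates views into the same
# storage, which is what lets a CUDA graph replay against refreshed contents.
num_tokens = 4
views = {k: v[:num_tokens] for k, v in pp_proxy_tensors.items()}
assert views["hidden_states"].data_ptr() == pp_proxy_tensors["hidden_states"].data_ptr()
```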
@@ -381,6 +401,12 @@ class CudaGraphRunner:
         encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
 
+        # pipeline parallelism
+        if self.pp_size > 1:
+            pp_proxy_tensors = PPProxyTensors(
+                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
+            )
+
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -403,6 +429,13 @@ class CudaGraphRunner:
         self.capture_hidden_mode = (
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )
+        if self.model_runner.server_args.lora_paths is not None:
+            # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
+            # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
+            # values if lora is enabled.
+            lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
+        else:
+            lora_paths = None
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -424,8 +457,12 @@ class CudaGraphRunner:
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
+            lora_paths=lora_paths,
         )
 
+        if lora_paths is not None:
+            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -442,8 +479,20 @@ class CudaGraphRunner:
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
 
-            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits, logits_output.hidden_states
+            kwargs = {}
+            if (
+                self.pp_size > 1
+                and "pp_proxy_tensors" in inspect.signature(forward).parameters
+            ):
+                kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+
+            logits_output_or_pp_proxy_tensors = forward(
+                input_ids,
+                forward_batch.positions,
+                forward_batch,
+                **kwargs,
+            )
+            return logits_output_or_pp_proxy_tensors
 
         for _ in range(2):
             torch.cuda.synchronize()
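The `inspect.signature` probe is what keeps this change backward compatible: the extra keyword is only passed when a model's `forward` actually declares `pp_proxy_tensors`, so models without pipeline-parallel support are called exactly as before. A self-contained sketch of the same pattern (both forward functions here are hypothetical stand-ins):

```python
import inspect

def legacy_forward(input_ids, positions, forward_batch):
    return "logits"

def pp_aware_forward(input_ids, positions, forward_batch, pp_proxy_tensors=None):
    return pp_proxy_tensors if pp_proxy_tensors is not None else "logits"

for forward in (legacy_forward, pp_aware_forward):
    kwargs = {}
    # Only pass the kwarg if the callee declares it, mirroring the hunk above.
    if "pp_proxy_tensors" in inspect.signature(forward).parameters:
        kwargs["pp_proxy_tensors"] = {"hidden_states": "..."}
    print(forward.__name__, "->", forward(None, None, None, **kwargs))
```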
@@ -476,7 +525,11 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def replay_prepare(self, forward_batch: ForwardBatch):
+    def replay_prepare(
+        self,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
@@ -505,6 +558,11 @@ class CudaGraphRunner:
             self.seq_lens_cpu.fill_(1)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
+        if pp_proxy_tensors:
+            for key in self.pp_proxy_tensors.keys():
+                dim = pp_proxy_tensors[key].shape[0]
+                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
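On the replay side, the incoming proxy tensors can cover fewer rows than the captured buffer, so they are copied into its front and the tail is left untouched; the graph then replays against the refreshed storage. A CPU-only sketch with illustrative shapes:

```python
import torch

static = {"hidden_states": torch.zeros(8, 16)}   # persistent captured buffer (max_bs = 8)
incoming = {"hidden_states": torch.ones(3, 16)}  # this replay's proxy tensor

for key in static:
    dim = incoming[key].shape[0]
    static[key][:dim].copy_(incoming[key])  # in-place refill of the buffer front

assert static["hidden_states"][:3].eq(1).all()
assert static["hidden_states"][3:].eq(0).all()
```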
@@ -533,10 +591,13 @@ class CudaGraphRunner:
         self.bs = bs
 
     def replay(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ) -> LogitsProcessorOutput:
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
         if not skip_attn_backend_init:
-            self.replay_prepare(forward_batch)
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
         else:
             # In speculative decoding, these two fields are still needed.
             self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
@@ -544,17 +605,19 @@ class CudaGraphRunner:
 
         # Replay
         self.graphs[self.bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[self.bs]
-
-        logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[: self.raw_num_token],
-            hidden_states=(
-                hidden_states[: self.raw_num_token]
-                if hidden_states is not None
-                else None
-            ),
-        )
-        return logits_output
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None
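Since `replay()` now returns either logits or proxy tensors, its callers have to branch on the result type: the last pipeline rank samples from logits, earlier ranks forward the proxies downstream. A runnable sketch with hypothetical stub classes (the real dispatch lives in the model runner, outside this hunk):

```python
from typing import Union

class StubLogitsOutput:           # stand-in for LogitsProcessorOutput
    next_token_logits = [0.1, 0.9]

class StubProxyTensors:           # stand-in for PPProxyTensors
    tensors = {"hidden_states": "..."}

def handle_replay_output(output: Union[StubLogitsOutput, StubProxyTensors]) -> str:
    if isinstance(output, StubProxyTensors):
        return "send proxy tensors to the next pipeline rank"
    return "sample from next_token_logits on the last rank"

print(handle_replay_output(StubLogitsOutput()))
print(handle_replay_output(StubProxyTensors()))
```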
sglang/srt/model_executor/forward_batch_info.py

@@ -31,7 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import torch
 import triton
@@ -585,6 +585,36 @@ class ForwardBatch:
         self.prepare_chunked_kv_indices(device)
 
 
+class PPProxyTensors:
+    # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
+    tensors: Dict[str, torch.Tensor]
+
+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `IntermediateTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
+        self.tensors = tensors
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: torch.Tensor):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return isinstance(other, self.__class__) and self
+
+    def __repr__(self) -> str:
+        return f"PPProxyTensors(tensors={self.tensors})"
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
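A quick usage sketch for the new container, assuming the class above is importable from `sglang.srt.model_executor.forward_batch_info` after this change: a string key returns one tensor, while a slice returns a new `PPProxyTensors` with every tensor sliced along its first dimension.

```python
import torch
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

proxies = PPProxyTensors(
    {"hidden_states": torch.randn(4, 8), "residual": torch.randn(4, 8)}
)
print(proxies["hidden_states"].shape)     # torch.Size([4, 8]) -- string key: one tensor
head = proxies[:2]                        # slice: new PPProxyTensors, rows 0-1 of each
print(len(head), head["residual"].shape)  # 2 torch.Size([2, 8])
```

Note that the `__eq__` shown in the hunk returns `isinstance(other, self.__class__) and self` rather than comparing the `tensors` dicts, so equality behaves more like a type check as shipped.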