sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py CHANGED
@@ -2,15 +2,18 @@ import logging
 from typing import List
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
+from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
 
-if is_flashinfer_available():
-    from flashinfer.sampling import (
+if is_cuda_available():
+    from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
         top_k_top_p_sampling_from_probs,
@@ -20,11 +23,17 @@ if is_flashinfer_available():
 
 logger = logging.getLogger(__name__)
 
+SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+
 
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+
+        if global_server_args_dict["enable_dp_attention"]:
+            self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(
         self,
@@ -35,6 +44,10 @@ class Sampler(nn.Module):
     ):
         logits = logits_output.next_token_logits
 
+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            self._apply_custom_logit_processor(logits, sampling_info)
+
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(
@@ -104,8 +117,6 @@ class Sampler(nn.Module):
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
 
-        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
-
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
             if any(x > 0 for x in top_logprobs_nums):
@@ -119,7 +130,54 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]
 
-        return batch_next_token_ids
+        if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
+            # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
+            # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
+            # the last all-reduce, the last lm_head matmul, and all sampling kernels.
+            # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
+            # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
+            # When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
+
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=self.tp_sync_group,
+            )
+
+        return batch_next_token_ids.to(torch.int32)
+
+    def _apply_custom_logit_processor(
+        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
+    ):
+        """Apply custom logit processors to the logits.
+        This function will modify the logits in-place."""
+
+        assert logits.shape[0] == len(sampling_batch_info), (
+            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+
+        for _, (
+            processor,
+            batch_mask,
+        ) in sampling_batch_info.custom_logit_processor.items():
+            # Get the batch indices that need to be processed
+            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+            assert batch_mask.shape[0] == len(sampling_batch_info), (
+                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+                f"sampling_batch_info ({len(sampling_batch_info)})"
+            )
+
+            # Apply the processor to the logits
+            logits[batch_mask] = processor(
+                logits[batch_mask],
+                [sampling_batch_info.custom_params[i] for i in batch_indices],
+            )
+
+            logger.debug(
+                f"Custom logit processor {processor.__class__.__name__} is applied."
+            )
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
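Note on the new custom logit processor hook above: the hunk only shows how registered processors are invoked, namely as a callable that receives the logits rows selected by its batch mask plus one custom_params entry per selected request, and returns logits of the same shape. The official interface presumably lives in the new sglang/srt/sampling/custom_logit_processor.py listed earlier (+38 lines). The sketch below is only an illustration of a callable matching that call signature; the class name and the "banned_token_ids" parameter are hypothetical, not part of sglang's API.

    import torch

    class BanTokenLogitProcessor:
        """Illustrative processor: mask out token ids listed in each request's custom params."""

        def __call__(self, logits: torch.Tensor, custom_param_list: list) -> torch.Tensor:
            # logits: [num_selected_requests, vocab_size], already sliced by batch_mask.
            # custom_param_list: one params dict per selected request,
            # e.g. {"banned_token_ids": [13, 42]} (hypothetical key).
            for row, params in enumerate(custom_param_list):
                for token_id in (params or {}).get("banned_token_ids", []):
                    logits[row, token_id] = float("-inf")
            return logits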
sglang/srt/layers/torchao_utils.py CHANGED
@@ -5,6 +5,7 @@ Common utilities for torchao.
 import logging
 import os
 import pwd
+from typing import Callable, Optional
 
 import torch
 
@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
     return True
 
 
+def proj_filter(
+    module: torch.nn.Module,
+    fqn: str,
+):
+    """Filter function for quantizing projection layers."""
+    return "proj" in fqn
+
+
 def apply_torchao_config_to_model(
-    model: torch.nn.Module, torchao_config: str, filter_fn=None
+    model: torch.nn.Module,
+    torchao_config: str,
+    filter_fn: Optional[Callable] = proj_filter,
 ):
     """Quantize a modelwith torchao quantization specified by torchao_config
 
@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-    if filter_fn is None:
-
-    def filter_fn(module, fqn):
-            return "proj" in fqn
-
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -6,13 +6,13 @@ from typing import List, Optional, Sequence, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
sglang/srt/lora/lora.py CHANGED
@@ -19,18 +19,11 @@
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
 
 
-import json
-import os
 import re
-from typing import Any, Dict, List, Optional, Tuple
 
-import safetensors.torch
 import torch
 from torch import nn
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -38,7 +31,6 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
 
sglang/srt/managers/configure_logging.py CHANGED
@@ -27,6 +27,7 @@ import requests
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", type=str, default="http://localhost:30000")
+    parser.add_argument("--log-requests", action="store_true")
     parser.add_argument(
         "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
     )
@@ -36,6 +37,8 @@ if __name__ == "__main__":
     response = requests.post(
         args.url + "/configure_logging",
         json={
+            "log_requests": args.log_requests,
+            "log_requests_level": 1,  # Log full requests
             "dump_requests_folder": args.dump_requests_folder,
             "dump_requests_threshold": args.dump_requests_threshold,
         },
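For reference, the request the script now sends when run with --log-requests looks like the sketch below. The URL and dump folder are the defaults visible in the hunks above; the threshold value is a placeholder, since its default is not shown here.

    import requests

    # Enable full request logging (level 1) on a locally running sglang server.
    requests.post(
        "http://localhost:30000/configure_logging",
        json={
            "log_requests": True,
            "log_requests_level": 1,  # Log full requests
            "dump_requests_folder": "/tmp/sglang_request_dump",
            "dump_requests_threshold": 1000,  # placeholder value
        },
    )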
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -23,6 +23,7 @@ import psutil
 import setproctitle
 import zmq
 
+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -55,6 +56,7 @@ class DataParallelController:
 
     def __init__(self, server_args, port_args) -> None:
         # Parse args
+        self.max_total_num_tokens = None
         self.server_args = server_args
         self.port_args = port_args
         self.load_balance_method = LoadBalanceMethod.from_str(
@@ -63,9 +65,10 @@ class DataParallelController:
 
         # Init inter-process communication
         self.context = zmq.Context(1 + server_args.dp_size)
-        self.recv_from_tokenizer = get_zmq_socket(
-            self.context, zmq.PULL, port_args.scheduler_input_ipc_name
-        )
+        if server_args.node_rank == 0:
+            self.recv_from_tokenizer = get_zmq_socket(
+                self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False
+            )
 
         # Dispatch method
         self.round_robin_counter = 0
@@ -75,33 +78,50 @@ class DataParallelController:
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
-        # Start data parallel workers
-        base_gpu_id = 0
+        # Launch data parallel workers
+        self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size
 
+        if not server_args.enable_dp_attention:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+        else:
+            dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+
+        # Only node rank 0 runs the real data parallel controller that dispatches the requests.
+        if server_args.node_rank == 0:
+            for dp_rank in range(server_args.dp_size):
+                self.workers[dp_rank] = get_zmq_socket(
+                    self.context,
+                    zmq.PUSH,
+                    dp_port_args[dp_rank].scheduler_input_ipc_name,
+                    True,
+                )
+
+        self.max_req_input_len = None
+
+    def launch_dp_schedulers(self, server_args, port_args):
+        base_gpu_id = 0
+
         threads = []
         sockets = []
+        dp_port_args = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
+            dp_port_args.append(tmp_port_args)
 
-            if server_args.enable_dp_attention:
-                # Data parallelism resues the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                tmp_port_args.nccl_port = port_args.nccl_port
-            else:
-                # This port is checked free in PortArgs.init_new.
-                # We hold it first so that the next dp worker gets a different port
-                sockets.append(bind_port(tmp_port_args.nccl_port))
+            # This port is checked free in PortArgs.init_new.
+            # We hold it first so that the next dp worker gets a different port
+            sockets.append(bind_port(tmp_port_args.nccl_port))
 
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.launch_worker_func,
+                target=self.launch_tensor_parallel_group,
                 args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
             threads.append(thread)
-            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+            base_gpu_id += server_args.tp_size
 
         # Free all sockets before starting the threads to launch TP workers
         for sock in sockets:
@@ -113,26 +133,14 @@ class DataParallelController:
         for thread in threads:
            thread.join()
 
-    def launch_worker_func(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+        return dp_port_args
 
-        launch_func_ = (
-            self.launch_tensor_parallel_process
-            if server_args.enable_dp_attention
-            else self.launch_tensor_parallel_group
-        )
-        self.workers[dp_rank] = launch_func_(
-            server_args,
-            port_args,
-            base_gpu_id,
-            dp_rank,
-        )
+    def launch_dp_attention_schedulers(self, server_args, port_args):
+        self.launch_tensor_parallel_group(server_args, port_args, 0, None)
+        dp_port_args = []
+        for dp_rank in range(server_args.dp_size):
+            dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
+        return dp_port_args
 
     def launch_tensor_parallel_group(
         self,
@@ -141,8 +149,10 @@ class DataParallelController:
         base_gpu_id: int,
         dp_rank: int,
     ):
+        if not server_args.enable_dp_attention:
+            logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+
         # Launch tensor parallel scheduler processes
-        scheduler_procs = []
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
         tp_rank_range = range(
@@ -150,52 +160,39 @@ class DataParallelController:
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
+            rank_port_args = port_args
+
+            if server_args.enable_dp_attention:
+                # dp attention has different sharding logic
+                _, _, dp_rank = compute_dp_attention_world_info(
+                    server_args.enable_dp_attention,
+                    tp_rank,
+                    server_args.tp_size,
+                    server_args.dp_size,
+                )
+                # compute zmq ports for this dp rank
+                rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                # Data parallelism resues the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                rank_port_args.nccl_port = port_args.nccl_port
+
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
            )
            proc.start()
-            scheduler_procs.append(proc)
+            self.scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)
 
-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        # Wait for model to finish loading and get max token nums
+        # Wait for model to finish loading
        scheduler_info = []
        for i in range(len(scheduler_pipe_readers)):
            scheduler_info.append(scheduler_pipe_readers[i].recv())
 
        self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
-
-        return send_to
-
-    def launch_tensor_parallel_process(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        reader, writer = mp.Pipe(duplex=False)
-        gpu_id = base_gpu_id
-        tp_rank = dp_rank
-        proc = mp.Process(
-            target=run_scheduler_process,
-            args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
-        )
-        proc.start()
-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        scheduler_info = reader.recv()
-        self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
-
-        return send_to
+        self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
     def round_robin_scheduler(self, req):
        self.workers[self.round_robin_counter].send_pyobj(req)
@@ -221,8 +218,8 @@ class DataParallelController:
            ):
                self.dispatching(recv_req)
            else:
-                # Send other control messages to all workers
-                for worker in self.workers:
+                # Send other control messages to first worker of tp group
+                for worker in self.workers[:: self.server_args.tp_size]:
                    worker.send_pyobj(recv_req)
 
 
@@ -238,9 +235,19 @@ def run_data_parallel_controller_process(
    try:
        controller = DataParallelController(server_args, port_args)
        pipe_writer.send(
-            {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
+            {
+                "status": "ready",
+                "max_total_num_tokens": controller.max_total_num_tokens,
+                "max_req_input_len": controller.max_req_input_len,
+            }
        )
-        controller.event_loop()
+        if server_args.node_rank == 0:
+            controller.event_loop()
+        for proc in controller.scheduler_procs:
+            proc.join()
+            logger.error(
+                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
+            )
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"DataParallelController hit an exception: {traceback}")
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -15,6 +15,7 @@
 
 import dataclasses
 import logging
+import os
 import signal
 from collections import OrderedDict
 from typing import Dict, List, Union
@@ -35,6 +36,12 @@ from sglang.utils import find_printable_text, get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
+# Maximum number of request states that detokenizer can hold. When exceeded,
+# oldest request states will be evicted. Default: 65536 (1<<16).
+# For more details, see: https://github.com/sgl-project/sglang/issues/2812
+# Use power of 2 values for better memory allocation.
+DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16))
+
 
 @dataclasses.dataclass
 class DecodeStatus:
  class DecodeStatus:
@@ -58,10 +65,10 @@ class DetokenizerManager:
58
65
  # Init inter-process communication
59
66
  context = zmq.Context(2)
60
67
  self.recv_from_scheduler = get_zmq_socket(
61
- context, zmq.PULL, port_args.detokenizer_ipc_name
68
+ context, zmq.PULL, port_args.detokenizer_ipc_name, True
62
69
  )
63
70
  self.send_to_tokenizer = get_zmq_socket(
64
- context, zmq.PUSH, port_args.tokenizer_ipc_name
71
+ context, zmq.PUSH, port_args.tokenizer_ipc_name, False
65
72
  )
66
73
 
67
74
  if server_args.skip_tokenizer_init:
@@ -71,9 +78,10 @@ class DetokenizerManager:
71
78
  server_args.tokenizer_path,
72
79
  tokenizer_mode=server_args.tokenizer_mode,
73
80
  trust_remote_code=server_args.trust_remote_code,
81
+ revision=server_args.revision,
74
82
  )
75
83
 
76
- self.decode_status = LimitedCapacityDict()
84
+ self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
77
85
 
78
86
  def trim_matched_stop(
79
87
  self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
@@ -155,7 +163,17 @@ class DetokenizerManager:
155
163
  # Incremental decoding
156
164
  output_strs = []
157
165
  for i in range(bs):
158
- s = self.decode_status[recv_obj.rids[i]]
166
+ try:
167
+ s = self.decode_status[recv_obj.rids[i]]
168
+ except KeyError:
169
+ raise RuntimeError(
170
+ f"Decode status not found for request {recv_obj.rids[i]}. "
171
+ "It may be due to the request being evicted from the decode status due to memory pressure. "
172
+ "Please increase the maximum number of requests by setting "
173
+ "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. "
174
+ f"The current value is {DETOKENIZER_MAX_STATES}. "
175
+ "For more details, see: https://github.com/sgl-project/sglang/issues/2812"
176
+ )
159
177
  new_text = read_texts[i][len(surr_texts[i]) :]
160
178
  if recv_obj.finished_reasons[i] is None:
161
179
  # Streaming chunk: update the decode status
@@ -183,6 +201,7 @@ class DetokenizerManager:
183
201
  prompt_tokens=recv_obj.prompt_tokens,
184
202
  completion_tokens=recv_obj.completion_tokens,
185
203
  cached_tokens=recv_obj.cached_tokens,
204
+ spec_verify_ct=recv_obj.spec_verify_ct,
186
205
  input_token_logprobs_val=recv_obj.input_token_logprobs_val,
187
206
  input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
188
207
  output_token_logprobs_val=recv_obj.output_token_logprobs_val,
@@ -191,13 +210,12 @@ class DetokenizerManager:
191
210
  input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
192
211
  output_top_logprobs_val=recv_obj.output_top_logprobs_val,
193
212
  output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
194
- normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
195
213
  )
196
214
  )
197
215
 
198
216
 
199
217
  class LimitedCapacityDict(OrderedDict):
200
- def __init__(self, capacity=1 << 15, *args, **kwargs):
218
+ def __init__(self, capacity: int, *args, **kwargs):
201
219
  super().__init__(*args, **kwargs)
202
220
  self.capacity = capacity
203
221
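The hunk above only shows the LimitedCapacityDict constructor gaining a required capacity argument; the eviction logic itself is elsewhere in the file. The sketch below is an assumed minimal implementation of a capacity-bounded OrderedDict that drops the oldest entry on overflow, matching the "oldest request states will be evicted" comment; the class name here is hypothetical.

    from collections import OrderedDict

    class BoundedStateDict(OrderedDict):
        """Assumed behavior: an OrderedDict that evicts its oldest entry once capacity is reached."""

        def __init__(self, capacity: int, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.capacity = capacity

        def __setitem__(self, key, value):
            if key not in self and len(self) >= self.capacity:
                # Evict the least recently inserted entry (FIFO order).
                self.popitem(last=False)
            super().__setitem__(key, value)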