sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py CHANGED
@@ -1,11 +1,12 @@
 import logging
-from typing import List
+from typing import Dict, List
 
 import torch
 from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
 
@@ -35,6 +36,10 @@ class Sampler(nn.Module):
     ):
         logits = logits_output.next_token_logits
 
+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            self._apply_custom_logit_processor(logits, sampling_info)
+
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(
@@ -121,6 +126,39 @@ class Sampler(nn.Module):
 
         return batch_next_token_ids
 
+    def _apply_custom_logit_processor(
+        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
+    ):
+        """Apply custom logit processors to the logits.
+        This function will modify the logits in-place."""
+
+        assert logits.shape[0] == len(sampling_batch_info), (
+            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+
+        for _, (
+            processor,
+            batch_mask,
+        ) in sampling_batch_info.custom_logit_processor.items():
+            # Get the batch indices that need to be processed
+            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+            assert batch_mask.shape[0] == len(sampling_batch_info), (
+                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+                f"sampling_batch_info ({len(sampling_batch_info)})"
+            )
+
+            # Apply the processor to the logits
+            logits[batch_mask] = processor(
+                logits[batch_mask],
+                [sampling_batch_info.custom_params[i] for i in batch_indices],
+            )
+
+            logger.debug(
+                f"Custom logit processor {processor.__class__.__name__} is applied."
+            )
+
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
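Note: the CustomLogitProcessor interface itself lives in sglang/srt/sampling/custom_logit_processor.py (added in this release, +38 lines) and is not shown here. Purely as an illustration of the call convention used above, where each registered processor receives the selected logits rows plus one custom_params dict per selected request and returns logits of the same shape, a hypothetical processor could look like the sketch below; the class name and the banned_token_ids parameter are invented for this example.

import torch

class BanTokenProcessor:  # illustrative stand-in, not sglang's CustomLogitProcessor
    def __call__(self, logits: torch.Tensor, custom_params: list) -> torch.Tensor:
        # custom_params[i] is the params dict of the i-th selected request.
        for row, params in enumerate(custom_params):
            for token_id in params.get("banned_token_ids", []):
                logits[row, token_id] = float("-inf")
        return logits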
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -6,13 +6,13 @@ from typing import List, Optional, Sequence, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.quant_config = quant_config
@@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.use_presharded_weights = use_presharded_weights
+        if use_presharded_weights:
+            assert (
+                num_added_embeddings == 0
+            ), "Lora is not supported with presharded weights."
+
         self.org_vocab_size_padded = pad_vocab_size(
             self.org_vocab_size, self.padding_size
         )
@@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module):
             start_idx = start_idx // packed_factor
             shard_size = shard_size // packed_factor
         else:
-            assert loaded_weight.shape[output_dim] == self.org_vocab_size
+            assert loaded_weight.shape[output_dim] == (
+                self.org_vocab_size
+                // (self.tp_size if self.use_presharded_weights else 1)
+            )
 
         # Copy the data.
-        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0] :].data.fill_(0)
 
@@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
            num_embeddings,
@@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding):
             padding_size,
             quant_config,
             prefix,
+            use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
         if bias:
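The use_presharded_weights flag only changes the copy step of weight loading: with a regular checkpoint the full vocabulary tensor is narrowed down to this rank's rows, while a presharded checkpoint is expected to already contain just those rows (org_vocab_size // tp_size along the shard dimension), so narrow() is skipped. A minimal sketch of the two paths, assuming output_dim=0 and made-up sizes:

import torch

tp_rank, tp_size, org_vocab_size, hidden = 1, 4, 32000, 512
shard_size = org_vocab_size // tp_size
start_idx = tp_rank * shard_size
param = torch.empty(shard_size, hidden)  # this rank's embedding shard

def load_embedding(loaded_weight: torch.Tensor, use_presharded_weights: bool) -> None:
    if not use_presharded_weights:
        # Full checkpoint: cut out this rank's rows before copying.
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
    # Presharded checkpoint: the tensor already holds only this rank's rows.
    param[: loaded_weight.shape[0]].copy_(loaded_weight)

load_embedding(torch.randn(org_vocab_size, hidden), use_presharded_weights=False)
load_embedding(torch.randn(shard_size, hidden), use_presharded_weights=True)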
sglang/srt/lora/lora.py CHANGED
@@ -19,18 +19,11 @@
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
 
 
-import json
-import os
 import re
-from typing import Any, Dict, List, Optional, Tuple
 
-import safetensors.torch
 import torch
 from torch import nn
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -38,7 +31,6 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
 
sglang/srt/managers/configure_logging.py ADDED
@@ -0,0 +1,46 @@
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Configure the logging settings of a server.
+
+Usage:
+python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    parser.add_argument("--log-requests", action="store_true")
+    parser.add_argument(
+        "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
+    )
+    parser.add_argument("--dump-requests-threshold", type=int, default=1000)
+    args = parser.parse_args()
+
+    response = requests.post(
+        args.url + "/configure_logging",
+        json={
+            "log_requests": args.log_requests,
+            "log_requests_level": 1,  # Log full requests
+            "dump_requests_folder": args.dump_requests_folder,
+            "dump_requests_threshold": args.dump_requests_threshold,
+        },
+    )
+    assert response.status_code == 200
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -23,6 +23,7 @@ import psutil
 import setproctitle
 import zmq
 
+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -55,6 +56,7 @@ class DataParallelController:
 
     def __init__(self, server_args, port_args) -> None:
         # Parse args
+        self.max_total_num_tokens = None
         self.server_args = server_args
         self.port_args = port_args
         self.load_balance_method = LoadBalanceMethod.from_str(
@@ -63,9 +65,10 @@ class DataParallelController:
 
         # Init inter-process communication
         self.context = zmq.Context(1 + server_args.dp_size)
-        self.recv_from_tokenizer = get_zmq_socket(
-            self.context, zmq.PULL, port_args.scheduler_input_ipc_name
-        )
+        if server_args.node_rank == 0:
+            self.recv_from_tokenizer = get_zmq_socket(
+                self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False
+            )
 
         # Dispatch method
         self.round_robin_counter = 0
@@ -75,33 +78,50 @@ class DataParallelController:
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
-        # Start data parallel workers
-        base_gpu_id = 0
+        # Launch data parallel workers
+        self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size
 
+        if not server_args.enable_dp_attention:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+        else:
+            dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+
+        # Only node rank 0 runs the real data parallel controller that dispatches the requests.
+        if server_args.node_rank == 0:
+            for dp_rank in range(server_args.dp_size):
+                self.workers[dp_rank] = get_zmq_socket(
+                    self.context,
+                    zmq.PUSH,
+                    dp_port_args[dp_rank].scheduler_input_ipc_name,
+                    True,
+                )
+
+        self.max_req_input_len = None
+
+    def launch_dp_schedulers(self, server_args, port_args):
+        base_gpu_id = 0
+
         threads = []
         sockets = []
+        dp_port_args = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
+            dp_port_args.append(tmp_port_args)
 
-            if server_args.enable_dp_attention:
-                # Data parallelism resues the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                tmp_port_args.nccl_port = port_args.nccl_port
-            else:
-                # This port is checked free in PortArgs.init_new.
-                # We hold it first so that the next dp worker gets a different port
-                sockets.append(bind_port(tmp_port_args.nccl_port))
+            # This port is checked free in PortArgs.init_new.
+            # We hold it first so that the next dp worker gets a different port
+            sockets.append(bind_port(tmp_port_args.nccl_port))
 
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.launch_worker_func,
+                target=self.launch_tensor_parallel_group,
                 args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
             threads.append(thread)
-            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+            base_gpu_id += server_args.tp_size
 
         # Free all sockets before starting the threads to launch TP workers
         for sock in sockets:
@@ -113,26 +133,14 @@ class DataParallelController:
         for thread in threads:
             thread.join()
 
-    def launch_worker_func(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+        return dp_port_args
 
-        launch_func_ = (
-            self.launch_tensor_parallel_process
-            if server_args.enable_dp_attention
-            else self.launch_tensor_parallel_group
-        )
-        self.workers[dp_rank] = launch_func_(
-            server_args,
-            port_args,
-            base_gpu_id,
-            dp_rank,
-        )
+    def launch_dp_attention_schedulers(self, server_args, port_args):
+        self.launch_tensor_parallel_group(server_args, port_args, 0, None)
+        dp_port_args = []
+        for dp_rank in range(server_args.dp_size):
+            dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
+        return dp_port_args
 
     def launch_tensor_parallel_group(
         self,
@@ -141,8 +149,10 @@ class DataParallelController:
         base_gpu_id: int,
         dp_rank: int,
     ):
+        if not server_args.enable_dp_attention:
+            logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+
         # Launch tensor parallel scheduler processes
-        scheduler_procs = []
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
         tp_rank_range = range(
@@ -150,52 +160,39 @@ class DataParallelController:
            tp_size_per_node * (server_args.node_rank + 1),
         )
         for tp_rank in tp_rank_range:
+            rank_port_args = port_args
+
+            if server_args.enable_dp_attention:
+                # dp attention has different sharding logic
+                _, _, dp_rank = compute_dp_attention_world_info(
+                    server_args.enable_dp_attention,
+                    tp_rank,
+                    server_args.tp_size,
+                    server_args.dp_size,
+                )
+                # compute zmq ports for this dp rank
+                rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                # Data parallelism resues the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                rank_port_args.nccl_port = port_args.nccl_port
+
             reader, writer = mp.Pipe(duplex=False)
             gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
             proc = mp.Process(
                 target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
             )
             proc.start()
-            scheduler_procs.append(proc)
+            self.scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        # Wait for model to finish loading and get max token nums
+        # Wait for model to finish loading
         scheduler_info = []
         for i in range(len(scheduler_pipe_readers)):
            scheduler_info.append(scheduler_pipe_readers[i].recv())
 
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
-
-        return send_to
-
-    def launch_tensor_parallel_process(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        reader, writer = mp.Pipe(duplex=False)
-        gpu_id = base_gpu_id
-        tp_rank = dp_rank
-        proc = mp.Process(
-            target=run_scheduler_process,
-            args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
-        )
-        proc.start()
-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        scheduler_info = reader.recv()
-        self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
-
-        return send_to
+        self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
@@ -221,8 +218,8 @@ class DataParallelController:
                 ):
                     self.dispatching(recv_req)
                 else:
-                    # Send other control messages to all workers
-                    for worker in self.workers:
+                    # Send other control messages to first worker of tp group
+                    for worker in self.workers[:: self.server_args.tp_size]:
                         worker.send_pyobj(recv_req)
 
 
@@ -238,9 +235,19 @@ def run_data_parallel_controller_process(
     try:
         controller = DataParallelController(server_args, port_args)
         pipe_writer.send(
-            {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
+            {
+                "status": "ready",
+                "max_total_num_tokens": controller.max_total_num_tokens,
+                "max_req_input_len": controller.max_req_input_len,
+            }
        )
-        controller.event_loop()
+        if server_args.node_rank == 0:
+            controller.event_loop()
+        for proc in controller.scheduler_procs:
+            proc.join()
+            logger.error(
+                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
+            )
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -15,6 +15,7 @@
 
 import dataclasses
 import logging
+import os
 import signal
 from collections import OrderedDict
 from typing import Dict, List, Union
@@ -35,6 +36,12 @@ from sglang.utils import find_printable_text, get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
+# Maximum number of request states that detokenizer can hold. When exceeded,
+# oldest request states will be evicted. Default: 65536 (1<<16).
+# For more details, see: https://github.com/sgl-project/sglang/issues/2812
+# Use power of 2 values for better memory allocation.
+DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16))
+
 
 @dataclasses.dataclass
 class DecodeStatus:
@@ -58,10 +65,10 @@ class DetokenizerManager:
         # Init inter-process communication
         context = zmq.Context(2)
         self.recv_from_scheduler = get_zmq_socket(
-            context, zmq.PULL, port_args.detokenizer_ipc_name
+            context, zmq.PULL, port_args.detokenizer_ipc_name, True
         )
         self.send_to_tokenizer = get_zmq_socket(
-            context, zmq.PUSH, port_args.tokenizer_ipc_name
+            context, zmq.PUSH, port_args.tokenizer_ipc_name, False
         )
 
         if server_args.skip_tokenizer_init:
@@ -71,9 +78,10 @@
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
 
-        self.decode_status = LimitedCapacityDict()
+        self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
 
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
@@ -155,7 +163,17 @@
             # Incremental decoding
             output_strs = []
             for i in range(bs):
-                s = self.decode_status[recv_obj.rids[i]]
+                try:
+                    s = self.decode_status[recv_obj.rids[i]]
+                except KeyError:
+                    raise RuntimeError(
+                        f"Decode status not found for request {recv_obj.rids[i]}. "
+                        "It may be due to the request being evicted from the decode status due to memory pressure. "
+                        "Please increase the maximum number of requests by setting "
+                        "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. "
+                        f"The current value is {DETOKENIZER_MAX_STATES}. "
+                        "For more details, see: https://github.com/sgl-project/sglang/issues/2812"
+                    )
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reasons[i] is None:
                     # Streaming chunk: update the decode status
@@ -181,8 +199,6 @@
                     finished_reasons=recv_obj.finished_reasons,
                     output_strs=output_strs,
                     prompt_tokens=recv_obj.prompt_tokens,
-                    origin_input_ids=recv_obj.origin_input_ids,
-                    output_ids=recv_obj.output_ids,
                     completion_tokens=recv_obj.completion_tokens,
                     cached_tokens=recv_obj.cached_tokens,
                     input_token_logprobs_val=recv_obj.input_token_logprobs_val,
@@ -193,13 +209,12 @@
                     input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
                     output_top_logprobs_val=recv_obj.output_top_logprobs_val,
                     output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
-                    normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
                 )
             )
 
 
 class LimitedCapacityDict(OrderedDict):
-    def __init__(self, capacity=1 << 15, *args, **kwargs):
+    def __init__(self, capacity: int, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.capacity = capacity
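The hunk above only shows LimitedCapacityDict.__init__; the eviction behavior that the new DETOKENIZER_MAX_STATES comment and the KeyError handler refer to (oldest entry dropped once the capacity is reached) is not included in this diff. A self-contained sketch of a dict with that behavior, under a hypothetical name:

from collections import OrderedDict

class BoundedStateDict(OrderedDict):
    """Illustrative capacity-limited dict: inserting beyond `capacity` evicts the oldest entry."""

    def __init__(self, capacity: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.capacity = capacity

    def __setitem__(self, key, value):
        if key not in self and len(self) >= self.capacity:
            self.popitem(last=False)  # drop the oldest request state
        super().__setitem__(key, value)

states = BoundedStateDict(capacity=2)
states["req-1"] = "decode state 1"
states["req-2"] = "decode state 2"
states["req-3"] = "decode state 3"
assert "req-1" not in states  # evicted; this is what the KeyError path above reports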