sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
 import logging
-from typing import Union
+from typing import List
 
 import torch
 from torch import nn
@@ -28,13 +28,12 @@ class Sampler(nn.Module):
 
     def forward(
         self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        logits_output: LogitsProcessorOutput,
         sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
     ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        logits = logits.contiguous()
+        logits = logits_output.next_token_logits
 
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
@@ -47,6 +46,8 @@ class Sampler(nn.Module):
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
             batch_next_token_ids = torch.argmax(logits, -1)
+            if return_logprob:
+                logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
             # Post process logits
             logits.div_(sampling_info.temperatures)
@@ -54,6 +55,14 @@ class Sampler(nn.Module):
             del logits
 
             if global_server_args_dict["sampling_backend"] == "flashinfer":
+                if return_logprob:
+                    # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
+                    # https://github.com/flashinfer-ai/flashinfer/issues/708
+                    # so we use the torch implementation.
+                    logprobs = torch.log(
+                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
+                    )
+
                 max_top_k_round, batch_size = 32, probs.shape[0]
                 uniform_samples = torch.rand(
                     (max_top_k_round, batch_size), device=probs.device
@@ -76,6 +85,7 @@ class Sampler(nn.Module):
                 if self.use_nan_detectioin and not torch.all(success):
                     logger.warning("Detected errors during sampling!")
                     batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+
             elif global_server_args_dict["sampling_backend"] == "pytorch":
                 # A slower fallback implementation with torch native operations.
                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
@@ -85,12 +95,31 @@ class Sampler(nn.Module):
                     sampling_info.min_ps,
                     sampling_info.need_min_p_sampling,
                 )
+                if return_logprob:
+                    logprobs = torch.log(
+                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
+                    )
             else:
                 raise ValueError(
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
 
-        return batch_next_token_ids.to(torch.int32)
+        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
+
+        # Attach logprobs to logits_output (in-place modification)
+        if return_logprob:
+            if any(x > 0 for x in top_logprobs_nums):
+                (
+                    logits_output.next_token_top_logprobs_val,
+                    logits_output.next_token_top_logprobs_idx,
+                ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+            logits_output.next_token_logprobs = logprobs[
+                torch.arange(len(batch_next_token_ids), device=sampling_info.device),
+                batch_next_token_ids,
+            ]
+
+        return batch_next_token_ids
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -120,20 +149,27 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids
 
 
-def top_p_normalize_probs(
+def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
 ):
-    if global_server_args_dict["sampling_backend"] == "flashinfer":
-        return top_p_renorm_prob(probs, top_ps)
-    elif global_server_args_dict["sampling_backend"] == "pytorch":
-        # See also top_k_top_p_min_p_sampling_from_probs_torch
-        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-        probs_sum = torch.cumsum(probs_sort, dim=-1)
-        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
-    else:
-        raise ValueError(
-            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-        )
+    # See also top_k_top_p_min_p_sampling_from_probs_torch
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+
+
+def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+    max_k = max(top_logprobs_nums)
+    ret = logprobs.topk(max_k, dim=1)
+    values = ret.values.tolist()
+    indices = ret.indices.tolist()
+
+    output_top_logprobs_val = []
+    output_top_logprobs_idx = []
+    for i, k in enumerate(top_logprobs_nums):
+        output_top_logprobs_val.append(values[i][:k])
+        output_top_logprobs_idx.append(indices[i][:k])
+    return output_top_logprobs_val, output_top_logprobs_idx
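
Both new helpers are self-contained, so their behavior is easy to check in isolation. Below is a minimal usage sketch, not part of the diff; it assumes the helpers are imported from sglang.srt.layers.sampler, which is where the hunks above appear to live according to the file list.

import torch

from sglang.srt.layers.sampler import get_top_logprobs, top_p_normalize_probs_torch

# Two requests over a 4-token vocabulary.
probs = torch.tensor(
    [
        [0.50, 0.30, 0.15, 0.05],
        [0.40, 0.40, 0.10, 0.10],
    ]
)
top_ps = torch.tensor([0.6, 1.0])

# Zero out tokens outside each request's top-p nucleus and renormalize.
# This pure-torch path is what the diff uses for logprobs even with the
# flashinfer backend, per the NOTE about flashinfer's top_p_renorm_prob.
logprobs = torch.log(top_p_normalize_probs_torch(probs, top_ps))

# Per-request top-k logprobs: request 0 asks for 2 entries, request 1 for 3.
top_vals, top_idx = get_top_logprobs(logprobs, [2, 3])
print(top_vals)  # plain Python lists, truncated to each request's k
print(top_idx)
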
@@ -11,6 +11,22 @@ import torch
 logger = logging.getLogger(__name__)
 
 
+def get_gemlite_cache_path() -> str:
+    return f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+
+
+def save_gemlite_cache(print_error: bool = False) -> bool:
+    try:
+        from gemlite.core import GemLiteLinearTriton
+
+        GemLiteLinearTriton.cache_config(get_gemlite_cache_path())
+    except Exception:
+        if print_error:
+            logger.error("Failed to save the GemLite cache.")
+        return False
+    return True
+
+
 def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
 ):
@@ -74,9 +90,7 @@ apply_torchao_config_to_model
         )
 
         # try to load gemlite kernel config
-        GemLiteLinearTriton.load_config(
-            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-        )
+        GemLiteLinearTriton.load_config(get_gemlite_cache_path())
 
     elif "fp8wo" in torchao_config:
         # this requires newer hardware
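
A usage sketch (not part of the diff) for the new cache helpers, assuming they are exported from sglang.srt.layers.torchao_utils as shown above; when and where sglang itself calls save_gemlite_cache is not visible in this hunk.

from sglang.srt.layers.torchao_utils import get_gemlite_cache_path, save_gemlite_cache

# After the GemLite kernels have been autotuned (e.g. after a warmup run),
# persist their configs so a later apply_torchao_config_to_model() call can
# reload them via GemLiteLinearTriton.load_config(get_gemlite_cache_path()).
if save_gemlite_cache(print_error=True):
    print(f"GemLite kernel config saved to {get_gemlite_cache_path()}")
else:
    print("GemLite cache not saved (gemlite missing or cache_config failed)")
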
@@ -12,8 +12,8 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.parameter import BasevLLMParameter
 
+from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -0,0 +1,307 @@
+from __future__ import annotations
+
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import threading
+from queue import PriorityQueue, Queue
+from typing import Optional
+
+import torch
+
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
+
+logger = logging.getLogger(__name__)
+
+
+class CacheOperation:
+
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        device_indices: torch.Tensor,
+        node_id: int,
+        priority: Optional[int] = None,
+    ):
+        self.host_indices = host_indices
+        self.device_indices = device_indices
+        self.node_ids = [node_id]
+        self.data = None
+
+        self.id = CacheOperation.counter
+        CacheOperation.counter += 1
+        # default priority is the order of creation
+        self.priority = priority if priority is not None else self.id
+
+    def merge(self, other: "CacheOperation") -> None:
+        # multiple operations can be merged into a single operation for batch processing
+        self.host_indices = torch.cat([self.host_indices, other.host_indices])
+        self.device_indices = torch.cat([self.device_indices, other.device_indices])
+        self.priority = min(self.priority, other.priority)
+        self.node_ids.extend(other.node_ids)
+
+    def __lt__(self, other: "CacheOperation"):
+        return self.priority < other.priority
+
+
+class TransferBuffer:
+    """
+    Overlapping buffer preparation and transfer operations to improve throughput.
+    """
+
+    def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
+        self.buffers = Queue(maxsize=buffer_count)
+        # todo: adjust the buffer size based on throughput profile of the system
+        self.max_buffer_size = max_buffer_size
+
+    def full(self) -> bool:
+        return self.buffers.full()
+
+    def empty(self) -> bool:
+        return self.buffers.empty()
+
+    def put(self, item, block=True) -> None:
+        self.buffers.put(item, block=block)
+
+    def get(self, block=True) -> Optional[CacheOperation]:
+        try:
+            return self.buffers.get(block=block)
+        except Exception as e:
+            logger.error(e)
+
+
+class HiCacheController:
+
+    def __init__(
+        self,
+        mem_pool_device: MHATokenToKVPool,
+        mem_pool_host: MLATokenToKVPoolHost,
+        write_policy: str = "write_through_selective",
+    ):
+
+        self.mem_pool_device = mem_pool_device
+        self.mem_pool_host = mem_pool_host
+        self.write_policy = write_policy
+
+        if write_policy not in [
+            "write_through",
+            "write_through_selective",
+            "write_back",
+        ]:
+            raise ValueError(f"Invalid write policy: {write_policy}")
+
+        self.write_queue = PriorityQueue()
+        self.load_queue = PriorityQueue()
+
+        self.ack_write_queue = Queue()
+        self.ack_load_queue = Queue()
+
+        self.write_buffer = TransferBuffer()
+        self.load_buffer = TransferBuffer()
+
+        self.write_stream = torch.cuda.Stream()
+        self.load_stream = torch.cuda.Stream()
+
+        self.write_thread = threading.Thread(
+            target=self.write_thread_func_buffer, daemon=True
+        )
+        self.load_thread = threading.Thread(
+            target=self.load_thread_func_buffer, daemon=True
+        )
+        self.write_thread.start()
+        self.load_thread.start()
+
+    def write(
+        self,
+        device_indices: torch.Tensor,
+        priority: Optional[int] = None,
+        node_id: int = 0,
+    ) -> Optional[torch.Tensor]:
+        """
+        Back up KV caches from device memory to host memory.
+        """
+        host_indices = self.mem_pool_host.alloc(len(device_indices))
+        if host_indices is None:
+            return None
+        self.write_queue.put(
+            CacheOperation(host_indices, device_indices, node_id, priority)
+        )
+        self.mem_pool_host.protect_write(host_indices)
+        return host_indices
+
+    def load(
+        self,
+        host_indices: torch.Tensor,
+        priority: Optional[int] = None,
+        node_id: int = 0,
+    ) -> Optional[torch.Tensor]:
+        """
+        Load KV caches from host memory to device memory.
+        """
+        device_indices = self.mem_pool_device.alloc(len(host_indices))
+        if device_indices is None:
+            return None
+        self.load_queue.put(
+            CacheOperation(host_indices, device_indices, node_id, priority)
+        )
+        self.mem_pool_host.protect_load(host_indices)
+        return device_indices
+
+    def write_thread_func_direct(self):
+        """
+        Directly write through KV caches to host memory without buffering.
+        """
+        with torch.cuda.stream(self.write_stream):
+            while True:
+                try:
+                    operation = self.write_queue.get(block=True)
+                    operation.data = self.mem_pool_device.get_flat_data(
+                        operation.device_indices
+                    )
+                    self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                    self.mem_pool_host.complete_io(operation.host_indices)
+                    for node_id in operation.node_ids:
+                        self.ack_write_queue.put(node_id)
+                except Exception as e:
+                    logger.error(e)
+
+    def load_thread_func_direct(self):
+        """
+        Directly load KV caches from host memory to device memory without buffering.
+        """
+        with torch.cuda.stream(self.load_stream):
+            while True:
+                try:
+                    operation = self.load_queue.get(block=True)
+                    operation.data = self.mem_pool_host.get_flat_data(
+                        operation.host_indices
+                    )
+                    self.mem_pool_device.transfer(
+                        operation.device_indices, operation.data
+                    )
+                    self.mem_pool_host.complete_io(operation.host_indices)
+                    for node_id in operation.node_ids:
+                        self.ack_load_queue.put(node_id)
+                except Exception as e:
+                    logger.error(e)
+
+    def write_aux_func(self, no_wait=False):
+        """
+        Auxiliary function to prepare the buffer for write operations.
+        """
+        buffer = None
+        while True:
+            try:
+                operation = self.write_queue.get(block=True)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    assert (
+                        buffer.device_indices.is_cuda
+                    ), "Device indices should be on GPU"
+                    buffer.data = self.mem_pool_device.get_flat_data(
+                        buffer.device_indices
+                    ).contiguous()
+                    self.write_buffer.put(buffer, block=True)
+                    buffer = None
+            except Exception as e:
+                logger.error(e)
+
+    def load_aux_func(self):
+        """
+        Auxiliary function to prepare the buffer for load operations.
+        """
+        buffer = None
+        while True:
+            try:
+                operation = self.load_queue.get(block=True)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
+                    or self.load_queue.empty()
+                    or self.load_buffer.empty()
+                ):
+                    buffer.data = (
+                        self.mem_pool_host.get_flat_data(buffer.host_indices)
+                        .contiguous()
+                        .pin_memory()
+                    )
+                    self.load_buffer.put(buffer, block=True)
+                    buffer = None
+            except Exception as e:
+                logger.error(e)
+
+    def write_thread_func_buffer(self):
+        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
+        aux_thread.start()
+        with torch.cuda.stream(self.write_stream):
+            while True:
+                operation = self.write_buffer.get()
+                if operation is None:
+                    continue
+                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    self.ack_write_queue.put(node_id)
+
+    def load_thread_func_buffer(self):
+        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
+        aux_thread.start()
+        with torch.cuda.stream(self.load_stream):
+            while True:
+                operation = self.load_buffer.get()
+                if operation is None:
+                    continue
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    self.ack_load_queue.put(node_id)
+
+    def evict_device(
+        self, device_indices: torch.Tensor, host_indices: torch.Tensor
+    ) -> int:
+        if self.mem_pool_host.is_synced(host_indices):
+            self.mem_pool_device.free(device_indices)
+            self.mem_pool_host.update_backup(host_indices)
+            return len(device_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )
+
+    def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
+        if not backup_only:
+            raise ValueError("Other eviction policies are not supported yet.")
+
+        if self.mem_pool_host.is_backup(host_indices):
+            self.mem_pool_host.free(host_indices)
+            return len(host_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )
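
A usage sketch (not part of the diff) for the new HiCacheController. The constructors of MHATokenToKVPool and MLATokenToKVPoolHost are not shown in this hunk, so the pools are taken as parameters; the write/ack/load flow follows the API above.

import torch

from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost


def backup_and_restore(device_pool: MHATokenToKVPool, host_pool: MLATokenToKVPoolHost):
    """Back up 128 device KV slots to host memory, wait for the ack, then reload them."""
    controller = HiCacheController(device_pool, host_pool, write_policy="write_through")

    device_indices = torch.arange(0, 128, dtype=torch.int64, device="cuda")
    host_indices = controller.write(device_indices, node_id=1)
    if host_indices is None:
        return None  # host pool has no free slots

    # The background write thread acks each node id once the copy completes.
    assert controller.ack_write_queue.get() == 1

    # Bring the same KV entries back onto the GPU (None if the device pool is full).
    return controller.load(host_indices, node_id=1)
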
@@ -20,6 +20,7 @@ import threading
 from enum import Enum, auto
 
 import psutil
+import setproctitle
 import zmq
 
 from sglang.srt.managers.io_struct import (
@@ -230,6 +231,7 @@ def run_data_parallel_controller_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    setproctitle.setproctitle("sglang::data_parallel_controller")
    configure_logger(server_args)
     parent_process = psutil.Process().parent()
 
@@ -426,8 +426,7 @@ class UpdateWeightsFromDistributedReqOutput:
 
 @dataclass
 class UpdateWeightsFromTensorReqInput:
-    name: str
-    tensor: torch.Tensor
+    serialized_named_tensors: bytes  # indeed Dict[str, torch.Tensor]
 
 
 @dataclass
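
The request now carries one opaque bytes payload instead of a single (name, tensor) pair. A hedged sketch of constructing it (not part of the diff); the serializer sglang actually uses is not shown here, so torch.save over an in-memory buffer stands in for it.

import io
from typing import Dict

import torch

from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput


def pack_named_tensors(named_tensors: Dict[str, torch.Tensor]) -> UpdateWeightsFromTensorReqInput:
    # Any scheme that round-trips Dict[str, torch.Tensor] to bytes fits the new field.
    buf = io.BytesIO()
    torch.save(named_tensors, buf)
    return UpdateWeightsFromTensorReqInput(serialized_named_tensors=buf.getvalue())


req = pack_named_tensors({"lm_head.weight": torch.zeros(8, 8)})
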
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,7 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 
 import dataclasses
 import logging
-from typing import List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -42,11 +44,15 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
+if TYPE_CHECKING:
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -565,9 +571,13 @@ class ScheduleBatch:
     # Has grammar
     has_grammar: bool = False
 
-    # device
+    # Device
     device: str = "cuda"
 
+    # Speculative decoding
+    spec_algorithm: SpeculativeAlgorithm = None
+    spec_info: Optional[SpecInfo] = None
+
     @classmethod
     def init_new(
         cls,
@@ -577,6 +587,7 @@
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
+        spec_algorithm: SpeculativeAlgorithm,
     ):
         return cls(
             reqs=reqs,
@@ -589,6 +600,7 @@
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
+            spec_algorithm=spec_algorithm,
         )
 
     def batch_size(self):
@@ -998,6 +1010,8 @@
 
     def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
+        if self.spec_algorithm.is_eagle():
+            return
 
         self.input_ids = self.output_ids
         self.output_ids = None
@@ -1103,6 +1117,9 @@
         self.has_stream |= other.has_stream
         self.has_grammar |= other.has_grammar
 
+        if self.spec_info:
+            self.spec_info.merge_batch(other.spec_info)
+
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
@@ -1144,6 +1161,13 @@
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            spec_algorithm=self.spec_algorithm,
+            spec_info=self.spec_info,
+            capture_hidden_mode=(
+                getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
+                if self.spec_info
+                else CaptureHiddenMode.NULL
+            ),
         )
 
     def copy(self):
@@ -1155,6 +1179,7 @@
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
+            spec_algorithm=self.spec_algorithm,
         )
 
     def __str__(self):
@@ -1214,6 +1239,11 @@
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # Speculative decoding
+    spec_algorithm: SpeculativeAlgorithm = None
+    spec_info: Optional[SpecInfo] = None
+    capture_hidden_mode: CaptureHiddenMode = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
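
The schedule_batch changes rely on "from __future__ import annotations" plus a TYPE_CHECKING-guarded import so the new speculative-decoding types can be used in annotations without a runtime import (and hence without an import cycle). A standalone illustration of that pattern, with a hypothetical module name standing in for the real spec_info module:

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by static type checkers; nothing is imported at runtime, so a
    # module that imports this one back does not create an import cycle.
    from my_project.spec_info import SpecInfo, SpeculativeAlgorithm


@dataclasses.dataclass
class Batch:
    # With "from __future__ import annotations", annotations are stored as strings
    # and never evaluated, so these names need not exist at runtime.
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None


print(Batch())  # Batch(spec_algorithm=None, spec_info=None)
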