sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +46 -25
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +184 -63
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -248
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/model_executor/cuda_graph_runner.py +15 -19
- sglang/srt/model_executor/forward_batch_info.py +94 -95
- sglang/srt/model_executor/model_runner.py +76 -75
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +14 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +71 -26
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +18 -9
- sglang/version.py +1 -1
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -474
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.2.dist-info/RECORD +0 -135
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/attention_backend.py
+++ /dev/null
@@ -1,474 +0,0 @@
-from __future__ import annotations
-
-"""
-Support different attention backends.
-Now there are two backends: FlashInfer and Triton.
-FlashInfer is faster and Triton is easier to customize.
-Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
-"""
-
-from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
-
-import torch
-import torch.nn as nn
-
-from sglang.global_config import global_config
-from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
-from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.utils import is_hip
-
-if TYPE_CHECKING:
-    from sglang.srt.model_executor.model_runner import ModelRunner
-
-# ROCm: flashinfer available later
-if not is_hip():
-    from flashinfer import (
-        BatchDecodeWithPagedKVCacheWrapper,
-        BatchPrefillWithPagedKVCacheWrapper,
-        BatchPrefillWithRaggedKVCacheWrapper,
-    )
-    from flashinfer.cascade import merge_state
-    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-
-class AttentionBackend(ABC):
-    """The base class of attention backends"""
-
-    @abstractmethod
-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
-        """Init the metadata for a forward pass."""
-        raise NotImplementedError()
-
-    def init_cuda_graph_state(self, max_bs: int):
-        """Init the global shared states for cuda graph."""
-        raise NotImplementedError()
-
-    def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
-    ):
-        """Init the metadata for a forward pass for capturing a cuda graph."""
-        raise NotImplementedError()
-
-    def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
-    ):
-        """Init the metadata for a forward pass for replying a cuda graph."""
-        raise NotImplementedError()
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
-        raise NotImplementedError()
-
-    def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-        """Run forward on an attention layer."""
-        if input_metadata.forward_mode.is_decode():
-            return self.forward_decode(q, k, v, layer, input_metadata)
-        else:
-            return self.forward_extend(q, k, v, layer, input_metadata)
-
-    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-        """Run a forward for decode."""
-        raise NotImplementedError()
-
-    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-        """Run a forward for extend."""
-        raise NotImplementedError()
-
-
-class FlashInferAttnBackend(AttentionBackend):
-    """Flashinfer attention kernels."""
-
-    def __init__(self, model_runner: ModelRunner):
-        super().__init__()
-        self.model_runner = model_runner
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            model_runner.model_config.num_attention_heads // model_runner.tp_size,
-            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-        ):
-            self.decode_use_tensor_cores = True
-        else:
-            self.decode_use_tensor_cores = False
-
-        self.workspace_buffer = torch.empty(
-            global_config.flashinfer_workspace_size,
-            dtype=torch.uint8,
-            device="cuda",
-        )
-
-        if model_runner.sliding_window_size is None:
-            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                self.workspace_buffer, "NHD"
-            )
-            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.workspace_buffer, "NHD"
-            )
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer,
-                "NHD",
-                use_tensor_cores=self.decode_use_tensor_cores,
-            )
-        else:
-            # Two wrappers: one for sliding window attention and one for full attention.
-            # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-            self.prefill_wrapper_ragged = None
-            self.prefill_wrapper_paged = []
-            self.decode_wrapper = []
-            for _ in range(2):
-                self.prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-                )
-                self.decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                    )
-                )
-
-        self.forward_metadata = None
-        self.cuda_graph_metadata = {}
-
-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
-        if input_metadata.forward_mode.is_decode():
-            prefix_lens = None
-            use_ragged = False
-            total_num_tokens = None
-        else:
-            prefix_lens = input_metadata.extend_prefix_lens
-
-            # Some heuristics to check whether to use ragged forward
-            use_ragged = False
-            if (
-                torch.sum(input_metadata.seq_lens).item() >= 4096
-                and self.model_runner.sliding_window_size is None
-            ):
-                use_ragged = True
-
-            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
-
-        update_flashinfer_indices(
-            input_metadata.forward_mode,
-            self.model_runner,
-            input_metadata.req_pool_indices,
-            input_metadata.seq_lens,
-            prefix_lens,
-            use_ragged=use_ragged,
-        )
-
-        self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
-
-    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.model_runner.model_config.context_len,),
-            dtype=torch.int32,
-            device="cuda",
-        )
-        self.cuda_graph_kv_last_page_len = torch.ones(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-
-        if self.model_runner.sliding_window_size is not None:
-            self.cuda_graph_kv_indptr = [
-                self.cuda_graph_kv_indptr,
-                self.cuda_graph_kv_indptr.clone(),
-            ]
-            self.cuda_graph_kv_indices = [
-                self.cuda_graph_kv_indices,
-                self.cuda_graph_kv_indices.clone(),
-            ]
-
-    def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
-    ):
-        if self.model_runner.sliding_window_size is None:
-            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=self.decode_use_tensor_cores,
-                paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.cuda_graph_kv_indices,
-                paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
-            )
-        else:
-            decode_wrapper = []
-            for i in range(2):
-                decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                        paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
-                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
-                            :bs
-                        ],
-                    )
-                )
-
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            decode_wrapper,
-        )
-
-        self.cuda_graph_metadata[bs] = decode_wrapper
-
-        self.forward_metadata = (False, None, decode_wrapper)
-
-    def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
-    ):
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            None,
-            self.cuda_graph_metadata[bs],
-        )
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 0
-
-    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-        if not isinstance(self.prefill_wrapper_paged, list):
-            prefill_wrapper_paged = self.prefill_wrapper_paged
-        else:
-            if layer.sliding_window_size != -1:
-                prefill_wrapper_paged = self.prefill_wrapper_paged[0]
-            else:
-                prefill_wrapper_paged = self.prefill_wrapper_paged[1]
-
-        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
-
-        if not use_ragged:
-            if k is not None:
-                assert v is not None
-                input_metadata.token_to_kv_pool.set_kv_buffer(
-                    layer.layer_id, input_metadata.out_cache_loc, k, v
-                )
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=True,
-                sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
-                logits_soft_cap=layer.logit_cap,
-            )
-        else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
-            )
-
-            if input_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=False,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=layer.logit_cap,
-                )
-
-                o, _ = merge_state(o1, s1, o2, s2)
-
-            input_metadata.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, input_metadata.out_cache_loc, k, v
-            )
-
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
-
-    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
-
-        if isinstance(decode_wrapper, list):
-            if layer.sliding_window_size != -1:
-                decode_wrapper = decode_wrapper[0]
-            else:
-                decode_wrapper = decode_wrapper[1]
-
-        if k is not None:
-            assert v is not None
-            input_metadata.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, input_metadata.out_cache_loc, k, v
-            )
-
-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-            sm_scale=layer.scaling,
-            logits_soft_cap=layer.logit_cap,
-        )
-
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
-
-
-class TritonAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):
-        # Lazy import to avoid the initialization of cuda context
-        from sglang.srt.layers.triton_attention.decode_attention import (
-            decode_attention_fwd,
-        )
-        from sglang.srt.layers.triton_attention.extend_attention import (
-            extend_attention_fwd,
-        )
-
-        super().__init__()
-
-        self.decode_attention_fwd = decode_attention_fwd
-        self.extend_attention_fwd = extend_attention_fwd
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
-
-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-            self.reduce_dtype = torch.float32
-        else:
-            self.reduce_dtype = torch.float16
-
-        self.forward_metadata = None
-
-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
-
-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
-        """Init auxiliary variables for triton attention backend."""
-
-        if input_metadata.forward_mode.is_decode():
-            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
-
-            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
-            attn_logits = torch.empty(
-                (self.num_head, total_num_tokens),
-                dtype=self.reduce_dtype,
-                device="cuda",
-            )
-
-            max_seq_len = torch.max(input_metadata.seq_lens).item()
-            max_extend_len = None
-        else:
-            start_loc = attn_logits = max_seq_len = None
-            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
-
-        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
-
-    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
-            device="cuda",
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
-    ):
-        self.forward_metadata = (
-            self.cuda_graph_start_loc,
-            self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
-            None,
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
-    ):
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
-    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-        # TODO: reuse the buffer across layers
-        if layer.qk_head_dim != layer.v_head_dim:
-            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
-        else:
-            o = torch.empty_like(q)
-
-        input_metadata.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, input_metadata.out_cache_loc, k, v
-        )
-
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
-        self.extend_attention_fwd(
-            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-            k.contiguous(),
-            v.contiguous(),
-            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.seq_lens,
-            input_metadata.extend_seq_lens,
-            input_metadata.extend_start_loc,
-            max_extend_len,
-            layer.scaling,
-            layer.logit_cap,
-        )
-        return o
-
-    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-        # During torch.compile, there is a bug in rotary_emb that causes the
-        # output value to have a 3D tensor shape. This reshapes the output correctly.
-        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
-
-        # TODO: reuse the buffer across layers
-        if layer.qk_head_dim != layer.v_head_dim:
-            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
-        else:
-            o = torch.empty_like(q)
-
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
-
-        input_metadata.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, input_metadata.out_cache_loc, k, v
-        )
-
-        self.decode_attention_fwd(
-            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            start_loc,
-            input_metadata.seq_lens,
-            attn_logits,
-            max_seq_len,
-            layer.scaling,
-            layer.logit_cap,
-        )
-        return o
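The module docstring of the removed `attention_backend.py` above states the contract every backend implements: an `extend` operator (prefill with a cached prefix) and a `decode` operator, selected inside `AttentionBackend.forward()` from the batch's forward mode; per the file list, this abstraction now lives in the new `sglang/srt/layers/attention/` package. Below is a minimal, framework-free sketch of that dispatch pattern. It does not call any sglang code: the `Mode`, `ToyBackend`, and `PrintBackend` names are hypothetical, and only the decode/extend split mirrors the deleted base class.

```python
# Toy illustration of the decode/extend dispatch used by the removed
# AttentionBackend base class. Everything here is hypothetical except the
# idea of routing on the forward mode.
from abc import ABC, abstractmethod
from enum import Enum, auto


class Mode(Enum):
    EXTEND = auto()  # prefill with a cached prefix
    DECODE = auto()  # one new token per running request

    def is_decode(self) -> bool:
        return self is Mode.DECODE


class ToyBackend(ABC):
    def forward(self, q, k, v, mode: Mode):
        # Same routing as the deleted AttentionBackend.forward()
        if mode.is_decode():
            return self.forward_decode(q, k, v)
        return self.forward_extend(q, k, v)

    @abstractmethod
    def forward_decode(self, q, k, v): ...

    @abstractmethod
    def forward_extend(self, q, k, v): ...


class PrintBackend(ToyBackend):
    def forward_decode(self, q, k, v):
        return f"decode: {len(q)} query tokens (one per running request)"

    def forward_extend(self, q, k, v):
        return f"extend: {len(q)} new prompt tokens across the batch"


if __name__ == "__main__":
    backend = PrintBackend()
    print(backend.forward([0] * 8, None, None, Mode.EXTEND))
    print(backend.forward([0] * 2, None, None, Mode.DECODE))
```

In the real backends shown in the deleted hunk, `forward_extend` runs the prefill kernels over the new prompt tokens while `forward_decode` attends one query token per running request against the paged KV cache.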
--- a/sglang/srt/managers/controller_multi.py
+++ /dev/null
@@ -1,207 +0,0 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-"""
-A controller that manages multiple data parallel workers.
-Each data parallel worker can manage multiple tensor parallel workers.
-"""
-
-import dataclasses
-import logging
-import multiprocessing
-from enum import Enum, auto
-
-import numpy as np
-import zmq
-
-from sglang.srt.managers.controller_single import (
-    start_controller_process as start_controller_process_single,
-)
-from sglang.srt.managers.io_struct import (
-    AbortReq,
-    FlushCacheReq,
-    TokenizedGenerateReqInput,
-)
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import configure_logger, kill_parent_process
-from sglang.utils import get_exception_traceback
-
-logger = logging.getLogger(__name__)
-
-
-class LoadBalanceMethod(Enum):
-    """Load balance method."""
-
-    ROUND_ROBIN = auto()
-    SHORTEST_QUEUE = auto()
-
-    @classmethod
-    def from_str(cls, method: str):
-        method = method.upper()
-        try:
-            return cls[method]
-        except KeyError as exc:
-            raise ValueError(f"Invalid load balance method: {method}") from exc
-
-
-@dataclasses.dataclass
-class WorkerHandle:
-    """Store the handle of a data parallel worker."""
-
-    proc: multiprocessing.Process
-    queue: multiprocessing.Queue
-
-
-class ControllerMulti:
-    """A controller that manages multiple data parallel workers."""
-
-    def __init__(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-    ):
-        # Parse args
-        self.server_args = server_args
-        self.port_args = port_args
-        self.load_balance_method = LoadBalanceMethod.from_str(
-            server_args.load_balance_method
-        )
-
-        # Init communication
-        context = zmq.Context()
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
-
-        # Dispatch method
-        self.round_robin_counter = 0
-        dispatch_lookup = {
-            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
-            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
-        }
-        self.dispatching = dispatch_lookup[self.load_balance_method]
-
-        # Start data parallel workers
-        self.workers = []
-        for i in range(server_args.dp_size):
-            self.start_dp_worker(i)
-
-    def start_dp_worker(self, dp_worker_id: int):
-        tp_size = self.server_args.tp_size
-
-        pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
-            duplex=False
-        )
-
-        gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
-        queue = multiprocessing.Queue()
-        proc = multiprocessing.Process(
-            target=start_controller_process_single,
-            args=(
-                self.server_args,
-                self.port_args,
-                pipe_controller_writer,
-                True,
-                gpu_ids,
-                dp_worker_id,
-                queue,
-            ),
-        )
-        proc.start()
-
-        controller_init_state = pipe_controller_reader.recv()
-        if controller_init_state != "init ok":
-            raise RuntimeError(
-                f"Initialization failed. controller_init_state: {controller_init_state}"
-            )
-        self.workers.append(
-            WorkerHandle(
-                proc=proc,
-                queue=queue,
-            )
-        )
-
-    def round_robin_scheduler(self, input_requests):
-        for r in input_requests:
-            self.workers[self.round_robin_counter].queue.put(r)
-            self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                self.workers
-            )
-
-    def shortest_queue_scheduler(self, input_requests):
-        for r in input_requests:
-            queue_sizes = [worker.queue.qsize() for worker in self.workers]
-            wid = np.argmin(queue_sizes)
-            self.workers[wid].queue.put(r)
-
-    def loop_for_forward(self):
-        while True:
-            recv_reqs = self.recv_requests()
-            self.dispatching(recv_reqs)
-
-    def recv_requests(self):
-        recv_reqs = []
-
-        while True:
-            try:
-                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-            except zmq.ZMQError:
-                break
-
-            if isinstance(recv_req, FlushCacheReq):
-                # TODO(lsyin): apply more specific flushCacheReq
-                for worker in self.workers:
-                    worker.queue.put(recv_req)
-            elif isinstance(recv_req, AbortReq):
-                in_queue = False
-                for i, req in enumerate(recv_reqs):
-                    if req.rid == recv_req.rid:
-                        recv_reqs[i] = recv_req
-                        in_queue = True
-                        break
-                if not in_queue:
-                    # Send abort req to all TP groups
-                    for worker in self.workers:
-                        worker.queue.put(recv_req)
-            elif isinstance(recv_req, TokenizedGenerateReqInput):
-                recv_reqs.append(recv_req)
-            else:
-                logger.error(f"Invalid object: {recv_req}")
-
-        return recv_reqs
-
-
-def start_controller_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
-):
-    """Start a controller process."""
-
-    configure_logger(server_args)
-
-    try:
-        controller = ControllerMulti(server_args, port_args)
-    except Exception:
-        pipe_writer.send(get_exception_traceback())
-        raise
-
-    pipe_writer.send("init ok")
-
-    try:
-        controller.loop_for_forward()
-    except Exception:
-        logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
-    finally:
-        kill_parent_process()
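The removed `controller_multi.py` above dispatched incoming requests across data parallel workers with two policies, `round_robin_scheduler` and `shortest_queue_scheduler`. Below is a standalone sketch of the same two policies over plain in-process queues; the `Worker` dataclass and function names are illustrative, and only the policy logic mirrors the deleted methods.

```python
# Standalone illustration of the two load-balance policies from the removed
# ControllerMulti: round-robin and shortest-queue. Uses queue.Queue instead of
# multiprocessing queues; the Worker container here is hypothetical.
import queue
from dataclasses import dataclass, field


@dataclass
class Worker:
    name: str
    requests: queue.Queue = field(default_factory=queue.Queue)


def round_robin(workers, input_requests, counter=0):
    # Cycle through workers, one request at a time (cf. round_robin_scheduler).
    for r in input_requests:
        workers[counter].requests.put(r)
        counter = (counter + 1) % len(workers)
    return counter  # carry the counter across batches


def shortest_queue(workers, input_requests):
    # Send each request to the worker with the fewest queued requests
    # (cf. shortest_queue_scheduler, which used np.argmin over qsize()).
    for r in input_requests:
        target = min(workers, key=lambda w: w.requests.qsize())
        target.requests.put(r)


if __name__ == "__main__":
    workers = [Worker("dp0"), Worker("dp1")]
    round_robin(workers, ["req-a", "req-b", "req-c"])
    shortest_queue(workers, ["req-d"])
    print({w.name: w.requests.qsize() for w in workers})  # {'dp0': 2, 'dp1': 2}
```

One caveat visible in the original code: `qsize()` on a `multiprocessing.Queue` is only an approximate count, so the shortest-queue choice is a heuristic rather than an exact balance.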