sglang 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/runtime_endpoint.py +14 -4
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -20
- sglang/bench_serving.py +758 -0
- sglang/check_env.py +171 -0
- sglang/global_config.py +3 -1
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/chat_template.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +31 -5
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +44 -18
- sglang/srt/managers/controller/infer_batch.py +76 -72
- sglang/srt/managers/controller/manager_multi.py +109 -98
- sglang/srt/managers/controller/manager_single.py +105 -50
- sglang/srt/managers/controller/model_runner.py +42 -18
- sglang/srt/managers/controller/radix_cache.py +4 -3
- sglang/srt/managers/controller/schedule_heuristic.py +4 -0
- sglang/srt/managers/controller/tp_worker.py +143 -156
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +46 -58
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +65 -16
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +2 -8
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +130 -108
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +114 -90
- sglang/srt/server_args.py +27 -17
- sglang/srt/utils.py +17 -118
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -159
- sglang-0.1.22.dist-info/RECORD +103 -0
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
- sglang-0.1.20.dist-info/RECORD +0 -82
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/tp_worker.py
CHANGED
@@ -1,15 +1,14 @@
 """A tensor parallel worker."""
 
-import asyncio
 import logging
+import multiprocessing
+import pickle
 import time
 import warnings
-from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional
 
-import rpyc
 import torch
-
+import torch.distributed as dist
 
 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_service_process,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -52,10 +49,9 @@ class ModelTpServer:
         gpu_id: int,
         tp_rank: int,
         server_args: ServerArgs,
-
-        model_overide_args,
+        nccl_port: int,
+        model_overide_args: dict,
     ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()
 
         # Copy arguments
@@ -79,7 +75,7 @@ class ModelTpServer:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
-            nccl_port=
+            nccl_port=nccl_port,
             server_args=server_args,
         )
 
@@ -98,7 +94,7 @@ class ModelTpServer:
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-
+            16384
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
@@ -178,9 +174,6 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
     def exposed_step(self, recv_reqs):
-        if self.tp_size * self.dp_size != 1:
-            recv_reqs = obtain(recv_reqs)
-
         try:
             # Recv requests
             for recv_req in recv_reqs:
@@ -206,11 +199,11 @@ class ModelTpServer:
 
     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.
+        new_batch = self.get_new_prefill_batch()
 
         if new_batch is not None:
-            # Run a new
-            self.
+            # Run a new prefill batch
+            self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)
 
             if not new_batch.is_empty():
@@ -219,33 +212,16 @@ class ModelTpServer:
                 else:
                     self.running_batch.merge(new_batch)
         else:
-            # Run decode batch
+            # Run a decode batch
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
-                for _ in range(
+                for _ in range(global_config.num_continue_decode_steps):
                     self.num_generated_tokens += len(self.running_batch.reqs)
                     self.forward_decode_batch(self.running_batch)
 
                     # Print stats
-                    if self.tp_rank == 0:
-
-                        num_used = self.max_total_num_tokens - (
-                            self.token_to_kv_pool.available_size()
-                            + self.tree_cache.evictable_size()
-                        )
-                        throughput = self.num_generated_tokens / (
-                            time.time() - self.last_stats_tic
-                        )
-                        self.num_generated_tokens = 0
-                        self.last_stats_tic = time.time()
-                        logger.info(
-                            f"[gpu_id={self.gpu_id}] Decode batch. "
-                            f"#running-req: {len(self.running_batch.reqs)}, "
-                            f"#token: {num_used}, "
-                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                            f"gen throughput (token/s): {throughput:.2f}, "
-                            f"#queue-req: {len(self.forward_queue)}"
-                        )
+                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                        self.print_stats()
 
                     if self.running_batch.is_empty():
                         self.running_batch = None
@@ -254,17 +230,34 @@ class ModelTpServer:
                     if self.out_pyobjs and self.running_batch.has_stream():
                         break
             else:
-
-
-
-
-
-
-
-
-
-
-
+                self.check_memory()
+
+    def print_stats(self):
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Decode batch. "
+            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"gen throughput (token/s): {throughput:.2f}, "
+            f"#queue-req: {len(self.forward_queue)}"
+        )
+
+    def check_memory(self):
+        available_size = (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        if available_size != self.max_total_num_tokens:
+            warnings.warn(
+                "Warning: "
+                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                "KV cache pool leak detected!"
+            )
 
     def handle_generate_request(
         self,
@@ -311,10 +304,18 @@ class ModelTpServer:
             self.model_config.context_len - 1 - len(req.origin_input_ids),
             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
         )
+        if req.sampling_params.max_new_tokens < 0:
+            req.origin_input_ids = req.origin_input_ids[
+                : self.max_total_num_tokens - 128
+            ]
+            logger.error("Request longer than memory pool size, truncated!!!")
+
         self.forward_queue.append(req)
 
-    def
-        running_bs =
+    def get_new_prefill_batch(self) -> Optional[Batch]:
+        running_bs = (
+            len(self.running_batch.reqs) if self.running_batch is not None else 0
+        )
         if running_bs >= self.max_running_requests:
             return
 
@@ -342,7 +343,8 @@ class ModelTpServer:
         if self.running_batch:
             available_size -= sum(
                 [
-                    (r.max_new_tokens
+                    (r.sampling_params.max_new_tokens - len(r.output_ids))
+                    * self.new_token_ratio
                     for r in self.running_batch.reqs
                 ]
             )
@@ -356,7 +358,7 @@ class ModelTpServer:
                     req.prefix_indices = req.prefix_indices[:-delta]
                     if req.image_offset is not None:
                         req.image_offset += delta
-            if req.extend_input_len == 0 and req.max_new_tokens
+            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
                 # Need at least one token to compute logits
                 req.extend_input_len = 1
                 req.prefix_indices = req.prefix_indices[:-1]
@@ -364,7 +366,9 @@ class ModelTpServer:
                     req.image_offset += 1
 
             if (
-                req.extend_input_len
+                req.extend_input_len
+                + req.sampling_params.max_new_tokens
+                + new_batch_total_tokens
                 < available_size
                 and (
                     req.extend_input_len + new_batch_input_tokens
@@ -376,7 +380,9 @@ class ModelTpServer:
                 available_size += delta
 
                 if not (
-                    req.extend_input_len
+                    req.extend_input_len
+                    + req.sampling_params.max_new_tokens
+                    + new_batch_total_tokens
                     < available_size
                 ):
                     # Undo locking
@@ -387,7 +393,7 @@ class ModelTpServer:
                     # Add this request to the running batch
                     can_run_list.append(req)
                     new_batch_total_tokens += (
-                        req.extend_input_len + req.max_new_tokens
+                        req.extend_input_len + req.sampling_params.max_new_tokens
                     )
                     new_batch_input_tokens += req.extend_input_len
                 else:
@@ -401,9 +407,6 @@ class ModelTpServer:
 
         # Print stats
         if self.tp_rank == 0:
-            running_req = (
-                0 if self.running_batch is None else len(self.running_batch.reqs)
-            )
             hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
             self.tree_cache_metrics["total"] += (
                 hit_tokens + new_batch_input_tokens
@@ -418,15 +421,9 @@ class ModelTpServer:
                 f"#new-token: {new_batch_input_tokens}, "
                 f"#cached-token: {hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                f"#running-req: {
+                f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
-            # logger.debug(
-            # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-            # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-            # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-            # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            # )
 
         # Return the new batch
         new_batch = Batch.init_new(
@@ -438,7 +435,7 @@ class ModelTpServer:
         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
         return new_batch
 
-    def
+    def forward_prefill_batch(self, batch: Batch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
@@ -447,7 +444,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids
+            next_token_ids = batch.sample(output.next_token_logits)
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -570,7 +567,7 @@ class ModelTpServer:
 
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids
+        next_token_ids = batch.sample(output.next_token_logits)
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
@@ -598,9 +595,10 @@ class ModelTpServer:
 
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        output_vids = []
        decoded_texts = []
-
-
+        output_read_ids = []
+        output_read_offsets = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -623,10 +621,11 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
+                output_vids.append(req.vid)
                 decoded_texts.append(req.decoded_text)
-
-
-
+                read_ids, read_offset = req.init_incremental_detokenize()
+                output_read_ids.append(read_ids)
+                output_read_offsets.append(read_offset)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -662,9 +661,10 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
+                    output_vids,
                     decoded_texts,
-
-
+                    output_read_ids,
+                    output_read_offsets,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -729,87 +729,74 @@ class ModelTpServer:
                     break
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
+):
+    """Run a tensor parallel server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            nccl_port,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(
+    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+):
+    """Launch multiple tensor parallel servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(
+            target=run_tp_server,
+            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+        )
+        proc.start()
+        procs.append(proc)
 
-
-        # Init model
-        assert len(gpu_ids) == 1
-        self.model_server = ModelTpService().exposed_ModelTpServer(
-            0,
-            gpu_ids[0],
-            server_args,
-            model_port_args,
-            model_overide_args,
-        )
+    return procs
 
-        # Wrap functions
-        def async_wrap(f):
-            async def _func(*args, **kwargs):
-                return f(*args, **kwargs)
 
-
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
-
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
        else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self.model_services = list(
-                executor.map(lambda args: connect_rpyc_service(*args), addrs)
-            )
-
-            # Init model
-            def init_model(i):
-                return self.model_services[i].ModelTpServer(
-                    gpu_ids[i],
-                    i,
-                    server_args,
-                    model_port_args,
-                    model_overide_args,
-                )
-
-            self.model_servers = list(executor.map(init_model, range(self.tp_size)))
-
-            # Wrap functions
-            def async_wrap(func_name):
-                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                async def _func(*args, **kwargs):
-                    tasks = [f(*args, **kwargs) for f in fs]
-                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                    return obtain(tasks[0].value)
-
-                return _func
-
-            self.step = async_wrap("step")
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,7 +1,9 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
+import dataclasses
 import inspect
+from typing import List
 
 import uvloop
 import zmq
@@ -16,6 +18,15 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
+@dataclasses.dataclass
+class DecodeStatus:
+    vid: int
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,
@@ -35,19 +46,43 @@ class DetokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        self.decode_status = {}
+
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)
+
+            # Initialize decode status
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                vid = recv_obj.vids[i]
+                if rid not in self.decode_status or self.decode_status[rid].vid != vid:
+                    s = DecodeStatus(
+                        vid=vid,
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             read_texts = self.tokenizer.batch_decode(
-
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
@@ -55,11 +90,20 @@ class DetokenizerManager:
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-
-
+                    # Streaming chunk: update the decode status
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+
+                output_strs.append(s.decoded_text + new_text)
 
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
sglang/srt/managers/io_struct.py
CHANGED
@@ -13,25 +13,26 @@ from sglang.srt.sampling_params import SamplingParams
 
 @dataclass
 class GenerateReqInput:
-    # The input prompt
+    # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids
+    # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
-    # The image input
+    # The image input. It can be a file name, a url, or base64 encoded string.
+    # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
-    # The sampling_params
+    # The sampling_params.
     sampling_params: Union[List[Dict], Dict] = None
-    # The request id
+    # The request id.
     rid: Optional[Union[List[str], str]] = None
-    # Whether to return logprobs
+    # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    # The start location of the prompt for return_logprob
+    # The start location of the prompt for return_logprob.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    # The number of top logprobs to return
+    # The number of top logprobs to return.
     top_logprobs_num: Optional[Union[List[int], int]] = None
-    # Whether to detokenize tokens in logprobs
+    # Whether to detokenize tokens in logprobs.
     return_text_in_logprobs: bool = False
-    # Whether to stream output
+    # Whether to stream output.
     stream: bool = False
 
     def post_init(self):
@@ -39,11 +40,13 @@ class GenerateReqInput:
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
-
-
-            is_single = isinstance(self.text, str)
+        if self.sampling_params.get("n", 1) != 1:
+            is_single = False
         else:
-
+            if self.text is not None:
+                is_single = isinstance(self.text, str)
+            else:
+                is_single = isinstance(self.input_ids[0], int)
         self.is_single = is_single
 
         if is_single:
@@ -58,7 +61,22 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-
+
+            parallel_sample_num = self.sampling_params.get("n", 1)
+
+            if parallel_sample_num != 1:
+                # parallel sampling +1 represents the original prefill stage
+                num = parallel_sample_num + 1
+                if isinstance(self.text, List):
+                    ## suppot batch operation
+                    self.batch_size = len(self.text)
+                    num = num * len(self.text)
+                else:
+                    self.batch_size = 1
+            else:
+                ## support select operation
+                num = len(self.text) if self.text is not None else len(self.input_ids)
+                self.batch_size = num
 
             if self.image_data is None:
                 self.image_data = [None] * num
@@ -110,9 +128,10 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
+    vids: List[int]
     decoded_texts: List[str]
-
-
+    decode_ids: List[int]
+    read_offsets: List[int]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]