sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -19
- sglang/bench_serving.py +976 -0
- sglang/check_env.py +171 -0
- sglang/global_config.py +3 -2
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/hf_transformers_utils.py +13 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/logits_processor.py +4 -5
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +39 -24
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
- sglang/srt/managers/controller/infer_batch.py +90 -63
- sglang/srt/managers/controller/manager_multi.py +107 -100
- sglang/srt/managers/controller/manager_single.py +76 -96
- sglang/srt/managers/controller/model_runner.py +41 -26
- sglang/srt/managers/controller/schedule_heuristic.py +8 -3
- sglang/srt/managers/controller/tp_worker.py +136 -149
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +32 -11
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +430 -0
- sglang/srt/models/gpt_bigcode.py +282 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +81 -23
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +132 -84
- sglang/srt/server_args.py +35 -21
- sglang/srt/utils.py +65 -117
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
- sglang-0.1.24.dist-info/RECORD +105 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
- sglang-0.1.21.dist-info/RECORD +0 -82
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/tp_worker.py
CHANGED
@@ -1,15 +1,14 @@
 """A tensor parallel worker."""
 
-import asyncio
 import logging
+import multiprocessing
+import pickle
 import time
 import warnings
-from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional
 
-import rpyc
 import torch
-
+import torch.distributed as dist
 
 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_service_process,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -52,10 +49,9 @@ class ModelTpServer:
         gpu_id: int,
         tp_rank: int,
         server_args: ServerArgs,
-
+        nccl_port: int,
         model_overide_args: dict,
     ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()
 
         # Copy arguments
@@ -79,7 +75,7 @@ class ModelTpServer:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
-            nccl_port=
+            nccl_port=nccl_port,
             server_args=server_args,
         )
 
@@ -107,6 +103,9 @@ class ModelTpServer:
             if server_args.max_running_requests is None
             else server_args.max_running_requests
         )
+        self.max_running_requests = min(
+            self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+        )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
@@ -117,13 +116,9 @@ class ModelTpServer:
             f"[gpu_id={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
+            f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}"
         )
-        if self.tp_rank == 0:
-            logger.info(
-                f"[gpu_id={self.gpu_id}] "
-                f"server_args: {server_args.print_mode_args()}"
-            )
 
         # Init cache
         self.tree_cache = RadixCache(
@@ -165,22 +160,16 @@ class ModelTpServer:
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-        self.new_token_ratio = min(
-            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
-            1.0,
-        )
         self.min_new_token_ratio = min(
             global_config.base_min_new_token_ratio
             * server_args.schedule_conservativeness,
             1.0,
         )
+        self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
     def exposed_step(self, recv_reqs):
-        if not isinstance(recv_reqs, list):
-            recv_reqs = obtain(recv_reqs)
-
         try:
             # Recv requests
             for recv_req in recv_reqs:
@@ -228,23 +217,7 @@ class ModelTpServer:
 
                     # Print stats
                     if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        num_used = self.max_total_num_tokens - (
-                            self.token_to_kv_pool.available_size()
-                            + self.tree_cache.evictable_size()
-                        )
-                        throughput = self.num_generated_tokens / (
-                            time.time() - self.last_stats_tic
-                        )
-                        self.num_generated_tokens = 0
-                        self.last_stats_tic = time.time()
-                        logger.info(
-                            f"[gpu_id={self.gpu_id}] Decode batch. "
-                            f"#running-req: {len(self.running_batch.reqs)}, "
-                            f"#token: {num_used}, "
-                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                            f"gen throughput (token/s): {throughput:.2f}, "
-                            f"#queue-req: {len(self.forward_queue)}"
-                        )
+                        self.print_stats()
 
                 if self.running_batch.is_empty():
                     self.running_batch = None
@@ -253,17 +226,35 @@ class ModelTpServer:
                     if self.out_pyobjs and self.running_batch.has_stream():
                         break
             else:
-
-
-
-
-
-
-
-
-
-
-
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+    def print_stats(self):
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Decode batch. "
+            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"gen throughput (token/s): {throughput:.2f}, "
+            f"#queue-req: {len(self.forward_queue)}"
+        )
+
+    def check_memory(self):
+        available_size = (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        if available_size != self.max_total_num_tokens:
+            warnings.warn(
+                "Warning: "
+                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                "KV cache pool leak detected!"
+            )
 
     def handle_generate_request(
         self,
@@ -310,6 +301,12 @@ class ModelTpServer:
             self.model_config.context_len - 1 - len(req.origin_input_ids),
             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
         )
+        if req.sampling_params.max_new_tokens < 0:
+            req.origin_input_ids = req.origin_input_ids[
+                : self.max_total_num_tokens - 128
+            ]
+            logger.error("Request longer than memory pool size, truncated!!!")
+
         self.forward_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
@@ -343,7 +340,8 @@ class ModelTpServer:
         if self.running_batch:
             available_size -= sum(
                 [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids))
+                    (r.sampling_params.max_new_tokens - len(r.output_ids))
+                    * self.new_token_ratio
                     for r in self.running_batch.reqs
                 ]
             )
@@ -365,7 +363,9 @@ class ModelTpServer:
                 req.image_offset += 1
 
             if (
-                req.extend_input_len
+                req.extend_input_len
+                + req.sampling_params.max_new_tokens
+                + new_batch_total_tokens
                 < available_size
                 and (
                     req.extend_input_len + new_batch_input_tokens
@@ -377,7 +377,9 @@ class ModelTpServer:
                 available_size += delta
 
                 if not (
-                    req.extend_input_len
+                    req.extend_input_len
+                    + req.sampling_params.max_new_tokens
+                    + new_batch_total_tokens
                     < available_size
                 ):
                     # Undo locking
@@ -419,12 +421,6 @@ class ModelTpServer:
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
-            # logger.debug(
-            #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-            #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-            #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-            #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            # )
 
         # Return the new batch
         new_batch = Batch.init_new(
@@ -445,7 +441,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids
+            next_token_ids = batch.sample(output.next_token_logits)
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -540,9 +536,10 @@ class ModelTpServer:
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
 
-            retracted_reqs = batch.retract_decode()
+            retracted_reqs, new_token_ratio = batch.retract_decode()
+            self.new_token_ratio = new_token_ratio
+
             logger.info(
                 "decode out of memory happened, "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
@@ -568,7 +565,7 @@ class ModelTpServer:
 
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids
+        next_token_ids = batch.sample(output.next_token_logits)
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
@@ -596,9 +593,10 @@ class ModelTpServer:
 
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        output_vids = []
         decoded_texts = []
-
-
+        output_read_ids = []
+        output_read_offsets = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -621,10 +619,11 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
+                output_vids.append(req.vid)
                 decoded_texts.append(req.decoded_text)
-
-
-
+                read_ids, read_offset = req.init_incremental_detokenize()
+                output_read_ids.append(read_ids)
+                output_read_offsets.append(read_offset)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -660,9 +659,10 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
+                    output_vids,
                     decoded_texts,
-
-
+                    output_read_ids,
+                    output_read_offsets,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -727,87 +727,74 @@ class ModelTpServer:
                 break
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
+):
+    """Run a tensor parallel server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            nccl_port,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(
+    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+):
+    """Launch multiple tensor parallel servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(
+            target=run_tp_server,
+            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+        )
+        proc.start()
+        procs.append(proc)
 
-
-            # Init model
-            assert len(gpu_ids) == 1
-            self.model_server = ModelTpService().exposed_ModelTpServer(
-                gpu_ids[0],
-                0,
-                server_args,
-                model_port_args,
-                model_overide_args,
-            )
+    return procs
 
-            # Wrap functions
-            def async_wrap(f):
-                async def _func(*args, **kwargs):
-                    return f(*args, **kwargs)
 
-
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
-
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self.model_services = list(
-                executor.map(lambda args: connect_rpyc_service(*args), addrs)
-            )
-
-            # Init model
-            def init_model(i):
-                return self.model_services[i].ModelTpServer(
-                    gpu_ids[i],
-                    i,
-                    server_args,
-                    model_port_args,
-                    model_overide_args,
-                )
-
-            self.model_servers = list(executor.map(init_model, range(self.tp_size)))
-
-            # Wrap functions
-            def async_wrap(func_name):
-                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                async def _func(*args, **kwargs):
-                    tasks = [f(*args, **kwargs) for f in fs]
-                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                    return obtain(tasks[0].value)
-
-                return _func
-
-            self.step = async_wrap("step")
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
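Note: the rewritten tp_worker drops the old rpyc control path in favor of plain multiprocessing plus torch.distributed broadcasts of pickled objects (see broadcast_recv_input above). The standalone sketch below isolates that broadcast pattern; the helper name broadcast_pyobj and the process-group setup it assumes are illustrative, not code shipped in this package.

import pickle

import torch
import torch.distributed as dist


def broadcast_pyobj(obj, rank, group):
    # Rank 0 serializes the object, broadcasts its byte length, then the bytes;
    # other ranks allocate a buffer of that length and unpickle it.
    # Assumes a CPU-capable process group (e.g. gloo) was already created via
    # dist.init_process_group / dist.new_group.
    if rank == 0:
        payload = pickle.dumps(obj)
        size = torch.tensor([len(payload)], dtype=torch.long)
        data = torch.ByteTensor(list(payload))
        dist.broadcast(size, src=0, group=group)
        dist.broadcast(data, src=0, group=group)
        return obj
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=group)
        data = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(data, src=0, group=group)
        return pickle.loads(bytes(data.tolist()))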
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,7 +1,9 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
+import dataclasses
 import inspect
+from typing import List
 
 import uvloop
 import zmq
@@ -16,6 +18,15 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
+@dataclasses.dataclass
+class DecodeStatus:
+    vid: int
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,
@@ -35,19 +46,43 @@ class DetokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        self.decode_status = {}
+
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)
+
+            # Initialize decode status
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                vid = recv_obj.vids[i]
+                if rid not in self.decode_status or self.decode_status[rid].vid != vid:
+                    s = DecodeStatus(
+                        vid=vid,
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             read_texts = self.tokenizer.batch_decode(
-
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
@@ -55,11 +90,20 @@ class DetokenizerManager:
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-
-
+                    # Streaming chunk: update the decode status
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+
+                output_strs.append(s.decoded_text + new_text)
 
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
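Note: this change implements incremental detokenization. For each request the DetokenizerManager keeps the full decode_ids plus two offsets: surr_offset marks the start of a small surrounding window of already-emitted tokens, and read_offset marks how far those tokens have been turned into text. Decoding the window with and without the newest tokens and diffing the two strings avoids re-decoding the whole sequence and avoids emitting half of a multi-byte character. A minimal standalone sketch of the same technique follows; the tokenizer choice is an illustrative assumption, not part of the diff.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")


def incremental_decode(decode_ids, surr_offset, read_offset):
    # Decode the window twice: up to read_offset, and with the new tokens included.
    surr_text = tokenizer.decode(decode_ids[surr_offset:read_offset])
    read_text = tokenizer.decode(decode_ids[surr_offset:])
    new_text = read_text[len(surr_text):]
    if new_text and not new_text.endswith("\ufffd"):
        # Safe to emit; slide the window forward, as DecodeStatus does.
        return new_text, read_offset, len(decode_ids)
    # Incomplete UTF-8 sequence: hold the text back until more tokens arrive.
    return "", surr_offset, read_offset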
sglang/srt/managers/io_struct.py
CHANGED
@@ -13,25 +13,26 @@ from sglang.srt.sampling_params import SamplingParams
 
 @dataclass
 class GenerateReqInput:
-    # The input prompt
+    # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids
+    # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
-    # The image input
+    # The image input. It can be a file name, a url, or base64 encoded string.
+    # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
-    # The sampling_params
+    # The sampling_params.
     sampling_params: Union[List[Dict], Dict] = None
-    # The request id
+    # The request id.
     rid: Optional[Union[List[str], str]] = None
-    # Whether to return logprobs
+    # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    # The start location of the prompt for return_logprob
+    # The start location of the prompt for return_logprob.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    # The number of top logprobs to return
+    # The number of top logprobs to return.
     top_logprobs_num: Optional[Union[List[int], int]] = None
-    # Whether to detokenize tokens in logprobs
+    # Whether to detokenize tokens in logprobs.
     return_text_in_logprobs: bool = False
-    # Whether to stream output
+    # Whether to stream output.
     stream: bool = False
 
     def post_init(self):
@@ -39,11 +40,13 @@ class GenerateReqInput:
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
-
-
-            is_single = isinstance(self.text, str)
+        if self.sampling_params.get("n", 1) != 1:
+            is_single = False
         else:
-
+            if self.text is not None:
+                is_single = isinstance(self.text, str)
+            else:
+                is_single = isinstance(self.input_ids[0], int)
         self.is_single = is_single
 
         if is_single:
@@ -58,7 +61,22 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-
+
+            parallel_sample_num = self.sampling_params.get("n", 1)
+
+            if parallel_sample_num != 1:
+                # parallel sampling +1 represents the original prefill stage
+                num = parallel_sample_num + 1
+                if isinstance(self.text, List):
+                    ## suppot batch operation
+                    self.batch_size = len(self.text)
+                    num = num * len(self.text)
+                else:
+                    self.batch_size = 1
+            else:
+                ## support select operation
+                num = len(self.text) if self.text is not None else len(self.input_ids)
+                self.batch_size = num
 
             if self.image_data is None:
                 self.image_data = [None] * num
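Note: in the new post_init, parallel sampling (sampling_params "n" > 1) expands the request list: each prompt gets one extra entry for the shared prefill pass, so a batch of B prompts with n samples yields B * (n + 1) entries, and per-request fields such as image_data are padded to that length. A small worked example of that arithmetic, as a standalone illustration rather than the package's own code:

def expanded_request_count(num_prompts: int, n: int) -> int:
    # Mirrors the `num` computed in GenerateReqInput.post_init:
    # with parallel sampling, +1 per prompt accounts for the shared prefill stage.
    if n != 1:
        return (n + 1) * num_prompts
    return num_prompts


assert expanded_request_count(num_prompts=1, n=1) == 1
assert expanded_request_count(num_prompts=1, n=4) == 5   # 4 samples + 1 prefill
assert expanded_request_count(num_prompts=3, n=4) == 15  # 3 * (4 + 1)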
@@ -110,9 +128,10 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
+    vids: List[int]
     decoded_texts: List[str]
-
-
+    decode_ids: List[int]
+    read_offsets: List[int]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]