sglang 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -19
- sglang/bench_serving.py +758 -0
- sglang/check_env.py +171 -0
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +31 -5
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +14 -12
- sglang/srt/managers/controller/infer_batch.py +47 -49
- sglang/srt/managers/controller/manager_multi.py +107 -100
- sglang/srt/managers/controller/manager_single.py +76 -96
- sglang/srt/managers/controller/model_runner.py +35 -23
- sglang/srt/managers/controller/tp_worker.py +127 -138
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +19 -6
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +65 -16
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +113 -84
- sglang/srt/server_args.py +23 -15
- sglang/srt/utils.py +16 -117
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -167
- sglang-0.1.22.dist-info/RECORD +103 -0
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
- sglang-0.1.21.dist-info/RECORD +0 -82
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/tp_worker.py
CHANGED
@@ -1,15 +1,14 @@
 """A tensor parallel worker."""
 
-import asyncio
 import logging
+import multiprocessing
+import pickle
 import time
 import warnings
-from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional
 
-import rpyc
 import torch
-
+import torch.distributed as dist
 
 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_service_process,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -52,10 +49,9 @@ class ModelTpServer:
         gpu_id: int,
         tp_rank: int,
         server_args: ServerArgs,
-
+        nccl_port: int,
         model_overide_args: dict,
     ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
        suppress_other_loggers()
 
        # Copy arguments
@@ -79,7 +75,7 @@ class ModelTpServer:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
-            nccl_port=
+            nccl_port=nccl_port,
             server_args=server_args,
         )
 
@@ -178,9 +174,6 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
     def exposed_step(self, recv_reqs):
-        if not isinstance(recv_reqs, list):
-            recv_reqs = obtain(recv_reqs)
-
         try:
             # Recv requests
             for recv_req in recv_reqs:
@@ -228,23 +221,7 @@ class ModelTpServer:
 
                     # Print stats
                     if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        num_used = self.max_total_num_tokens - (
-                            self.token_to_kv_pool.available_size()
-                            + self.tree_cache.evictable_size()
-                        )
-                        throughput = self.num_generated_tokens / (
-                            time.time() - self.last_stats_tic
-                        )
-                        self.num_generated_tokens = 0
-                        self.last_stats_tic = time.time()
-                        logger.info(
-                            f"[gpu_id={self.gpu_id}] Decode batch. "
-                            f"#running-req: {len(self.running_batch.reqs)}, "
-                            f"#token: {num_used}, "
-                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                            f"gen throughput (token/s): {throughput:.2f}, "
-                            f"#queue-req: {len(self.forward_queue)}"
-                        )
+                        self.print_stats()
 
                     if self.running_batch.is_empty():
                         self.running_batch = None
@@ -253,17 +230,34 @@ class ModelTpServer:
                     if self.out_pyobjs and self.running_batch.has_stream():
                         break
             else:
-                [... 11 removed lines not captured in this diff view ...]
+                self.check_memory()
+
+    def print_stats(self):
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Decode batch. "
+            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"gen throughput (token/s): {throughput:.2f}, "
+            f"#queue-req: {len(self.forward_queue)}"
+        )
+
+    def check_memory(self):
+        available_size = (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        if available_size != self.max_total_num_tokens:
+            warnings.warn(
+                "Warning: "
+                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                "KV cache pool leak detected!"
+            )
 
     def handle_generate_request(
         self,
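The refactor above folds the inline decode-batch logging and the idle-time leak check into two helpers, print_stats() and check_memory(), both built on the same accounting: tokens that are neither free in the KV pool nor evictable from the radix cache must be held by running requests. A small self-contained sketch of that arithmetic with made-up pool sizes (the numbers are hypothetical, not taken from the diff):

# Hypothetical sizes, used only to illustrate the accounting in print_stats()/check_memory().
max_total_num_tokens = 10_000      # total KV-token budget of the pool
kv_pool_available = 7_000          # free slots in token_to_kv_pool
tree_cache_evictable = 2_500       # cached-but-evictable slots in the radix tree

# Tokens currently pinned by running requests.
num_used = max_total_num_tokens - (kv_pool_available + tree_cache_evictable)
print(f"#token: {num_used}, token usage: {num_used / max_total_num_tokens:.2f}")  # 500, 0.05

# When the server is idle, check_memory() expects the two free pools to add back
# up to the full budget; any shortfall means KV-cache slots leaked.
if kv_pool_available + tree_cache_evictable != max_total_num_tokens:
    print("KV cache pool leak detected!")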
@@ -310,6 +304,12 @@ class ModelTpServer:
            self.model_config.context_len - 1 - len(req.origin_input_ids),
            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
        )
+        if req.sampling_params.max_new_tokens < 0:
+            req.origin_input_ids = req.origin_input_ids[
+                : self.max_total_num_tokens - 128
+            ]
+            logger.error("Request longer than memory pool size, truncated!!!")
+
         self.forward_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
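The added guard in handle_generate_request clamps max_new_tokens against both the model context length and the KV-token budget, and when the clamp goes negative it truncates the prompt itself. A standalone sketch of that logic (the helper name and the sizes are hypothetical):

def clamp_request(origin_input_ids, max_new_tokens, context_len, max_total_num_tokens):
    # Mirror of the clamp in handle_generate_request: leave room for the prompt
    # in both the model context and the memory pool (minus a 128-token margin).
    max_new_tokens = min(
        max_new_tokens,
        context_len - 1 - len(origin_input_ids),
        max_total_num_tokens - 128 - len(origin_input_ids),
    )
    if max_new_tokens < 0:
        # The prompt alone exceeds the memory pool budget: keep only a prefix.
        origin_input_ids = origin_input_ids[: max_total_num_tokens - 128]
        print("Request longer than memory pool size, truncated!!!")
    return origin_input_ids, max_new_tokens

# Example: a 9,950-token prompt against a 10,000-token pool leaves no decode headroom.
ids, budget = clamp_request(list(range(9_950)), 256, context_len=16_384, max_total_num_tokens=10_000)
assert len(ids) == 9_872 and budget < 0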
@@ -343,7 +343,8 @@ class ModelTpServer:
         if self.running_batch:
             available_size -= sum(
                 [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids))
+                    (r.sampling_params.max_new_tokens - len(r.output_ids))
+                    * self.new_token_ratio
                     for r in self.running_batch.reqs
                 ]
             )
@@ -365,7 +366,9 @@ class ModelTpServer:
                 req.image_offset += 1
 
             if (
-                req.extend_input_len
+                req.extend_input_len
+                + req.sampling_params.max_new_tokens
+                + new_batch_total_tokens
                 < available_size
                 and (
                     req.extend_input_len + new_batch_input_tokens
@@ -377,7 +380,9 @@ class ModelTpServer:
                 available_size += delta
 
                 if not (
-                    req.extend_input_len
+                    req.extend_input_len
+                    + req.sampling_params.max_new_tokens
+                    + new_batch_total_tokens
                     < available_size
                 ):
                     # Undo locking
@@ -419,12 +424,6 @@ class ModelTpServer:
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
-            # logger.debug(
-            #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-            #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-            #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-            #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            # )
 
         # Return the new batch
         new_batch = Batch.init_new(
@@ -445,7 +444,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids
+            next_token_ids = batch.sample(output.next_token_logits)
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -568,7 +567,7 @@ class ModelTpServer:
 
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids
+        next_token_ids = batch.sample(output.next_token_logits)
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
@@ -596,9 +595,10 @@ class ModelTpServer:
 
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        output_vids = []
         decoded_texts = []
-
-
+        output_read_ids = []
+        output_read_offsets = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -621,10 +621,11 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
+                output_vids.append(req.vid)
                 decoded_texts.append(req.decoded_text)
-
-
-
+                read_ids, read_offset = req.init_incremental_detokenize()
+                output_read_ids.append(read_ids)
+                output_read_offsets.append(read_offset)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -660,9 +661,10 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
+                    output_vids,
                     decoded_texts,
-
-
+                    output_read_ids,
+                    output_read_offsets,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -727,87 +729,74 @@ class ModelTpServer:
                 break
 
 
-[... 14 removed lines not captured in this diff view ...]
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
+):
+    """Run a tensor parallel server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            nccl_port,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(
+    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+):
+    """Launch multiple tensor parallel servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(
+            target=run_tp_server,
+            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+        )
+        proc.start()
+        procs.append(proc)
 
-
-            # Init model
-            assert len(gpu_ids) == 1
-            self.model_server = ModelTpService().exposed_ModelTpServer(
-                gpu_ids[0],
-                0,
-                server_args,
-                model_port_args,
-                model_overide_args,
-            )
+    return procs
 
-            # Wrap functions
-            def async_wrap(f):
-                async def _func(*args, **kwargs):
-                    return f(*args, **kwargs)
 
-
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
-
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
         else:
-[... 21 removed lines not captured in this diff view ...]
-            self.model_services = list(
-                executor.map(lambda args: connect_rpyc_service(*args), addrs)
-            )
-
-            # Init model
-            def init_model(i):
-                return self.model_services[i].ModelTpServer(
-                    gpu_ids[i],
-                    i,
-                    server_args,
-                    model_port_args,
-                    model_overide_args,
-                )
-
-            self.model_servers = list(executor.map(init_model, range(self.tp_size)))
-
-            # Wrap functions
-            def async_wrap(func_name):
-                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                async def _func(*args, **kwargs):
-                    tasks = [f(*args, **kwargs) for f in fs]
-                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                    return obtain(tasks[0].value)
-
-                return _func
-
-            self.step = async_wrap("step")
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
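run_tp_server and broadcast_recv_input above replace the old rpyc RPC path: rank 0 pickles the request list into a byte tensor and broadcasts first its length and then the payload over the tensor-parallel CPU group, while the other ranks pass data=None and receive the deserialized list in their step loop. A minimal two-process sketch of the same pattern over a gloo group (the rendezvous address, port, and world size are hypothetical):

import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def broadcast_pyobj(data, rank, group):
    # Same length-then-payload pattern as broadcast_recv_input in tp_worker.py.
    if rank == 0:
        payload = pickle.dumps(data)
        size = torch.tensor([len(payload)], dtype=torch.long)
        buf = torch.ByteTensor(list(payload))
        dist.broadcast(size, src=0, group=group)
        dist.broadcast(buf, src=0, group=group)
        return data
    size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(size, src=0, group=group)
    buf = torch.empty(size.item(), dtype=torch.uint8)
    dist.broadcast(buf, src=0, group=group)
    return pickle.loads(bytes(buf.tolist()))


def worker(rank, world_size):
    dist.init_process_group(
        backend="gloo",  # CPU group, analogous to the tp_cpu_group used by run_tp_server
        init_method="tcp://127.0.0.1:29555",  # hypothetical rendezvous address
        rank=rank,
        world_size=world_size,
    )
    reqs = ["req-0", "req-1"] if rank == 0 else None  # only rank 0 holds the inputs
    print(f"rank {rank} received: {broadcast_pyobj(reqs, rank, dist.group.WORLD)}")


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)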
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,7 +1,9 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
+import dataclasses
 import inspect
+from typing import List
 
 import uvloop
 import zmq
@@ -16,6 +18,15 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
+@dataclasses.dataclass
+class DecodeStatus:
+    vid: int
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,
@@ -35,19 +46,43 @@ class DetokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        self.decode_status = {}
+
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)
+
+            # Initialize decode status
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                vid = recv_obj.vids[i]
+                if rid not in self.decode_status or self.decode_status[rid].vid != vid:
+                    s = DecodeStatus(
+                        vid=vid,
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             read_texts = self.tokenizer.batch_decode(
-
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
@@ -55,11 +90,20 @@ class DetokenizerManager:
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-
-
+                    # Streaming chunk: update the decode status
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+
+                output_strs.append(s.decoded_text + new_text)
 
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
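The DecodeStatus bookkeeping above implements incremental detokenization: surr_offset marks the start of a small surrounding window that is re-decoded every step so multi-token characters are never split across chunks, read_offset marks where already-surfaced text ends, and the window only advances once the newly decoded suffix does not end in the replacement character "�". A rough standalone sketch of that loop with a Hugging Face tokenizer (the model name is only an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example model


def stream_detokenize(token_ids):
    # Roughly mirrors DetokenizerManager + DecodeStatus for a single request.
    decoded_text, surr_offset, read_offset = "", 0, 0
    for n in range(1, len(token_ids) + 1):
        decode_ids = token_ids[:n]                                   # tokens received so far
        surr_text = tokenizer.decode(decode_ids[surr_offset:read_offset])
        read_text = tokenizer.decode(decode_ids[surr_offset:])
        new_text = read_text[len(surr_text):]
        if new_text and not new_text.endswith("�"):
            decoded_text += new_text                                 # safe to surface this chunk
            surr_offset, read_offset = read_offset, len(decode_ids)  # advance the window
            yield new_text


ids = tokenizer.encode("Incremental detokenization avoids emitting half-decoded characters.")
print("".join(stream_detokenize(ids)))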
sglang/srt/managers/io_struct.py
CHANGED
@@ -13,25 +13,26 @@ from sglang.srt.sampling_params import SamplingParams
 
 @dataclass
 class GenerateReqInput:
-    # The input prompt
+    # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids
+    # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
-    # The image input
+    # The image input. It can be a file name, a url, or base64 encoded string.
+    # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
-    # The sampling_params
+    # The sampling_params.
     sampling_params: Union[List[Dict], Dict] = None
-    # The request id
+    # The request id.
     rid: Optional[Union[List[str], str]] = None
-    # Whether to return logprobs
+    # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    # The start location of the prompt for return_logprob
+    # The start location of the prompt for return_logprob.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    # The number of top logprobs to return
+    # The number of top logprobs to return.
     top_logprobs_num: Optional[Union[List[int], int]] = None
-    # Whether to detokenize tokens in logprobs
+    # Whether to detokenize tokens in logprobs.
     return_text_in_logprobs: bool = False
-    # Whether to stream output
+    # Whether to stream output.
     stream: bool = False
 
     def post_init(self):
@@ -39,11 +40,13 @@ class GenerateReqInput:
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
-
-
-            is_single = isinstance(self.text, str)
+        if self.sampling_params.get("n", 1) != 1:
+            is_single = False
         else:
-
+            if self.text is not None:
+                is_single = isinstance(self.text, str)
+            else:
+                is_single = isinstance(self.input_ids[0], int)
         self.is_single = is_single
 
         if is_single:
@@ -58,7 +61,22 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-
+
+            parallel_sample_num = self.sampling_params.get("n", 1)
+
+            if parallel_sample_num != 1:
+                # parallel sampling +1 represents the original prefill stage
+                num = parallel_sample_num + 1
+                if isinstance(self.text, List):
+                    ## suppot batch operation
+                    self.batch_size = len(self.text)
+                    num = num * len(self.text)
+                else:
+                    self.batch_size = 1
+            else:
+                ## support select operation
+                num = len(self.text) if self.text is not None else len(self.input_ids)
+                self.batch_size = num
 
             if self.image_data is None:
                 self.image_data = [None] * num
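With parallel sampling (n > 1 in sampling_params), post_init above sizes the per-request lists for (n + 1) entries per prompt, the extra one being the shared prefill stage; without parallel sampling the count is simply the number of prompts in the batch. A tiny worked version of that bookkeeping (a standalone sketch, not the class method itself):

def expected_list_len(num_prompts: int, n: int) -> int:
    # How many rid/image_data/... slots post_init allocates (sketch of the logic above).
    if n != 1:
        # +1 per prompt for the original prefill stage, plus n sampled continuations.
        return (n + 1) * num_prompts
    return num_prompts


assert expected_list_len(num_prompts=1, n=4) == 5    # 1 prefill + 4 samples
assert expected_list_len(num_prompts=3, n=4) == 15   # 3 * (1 + 4)
assert expected_list_len(num_prompts=3, n=1) == 3    # plain batch, one entry per prompt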
@@ -110,9 +128,10 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
+    vids: List[int]
     decoded_texts: List[str]
-
-
+    decode_ids: List[int]
+    read_offsets: List[int]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]