sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +4 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +4 -1
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +1 -1
- sglang/lang/ir.py +15 -5
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +64 -9
- sglang/srt/layers/fused_moe.py +186 -89
- sglang/srt/layers/logits_processor.py +53 -25
- sglang/srt/layers/radix_attention.py +34 -7
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +142 -67
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +8 -3
- sglang/srt/managers/controller/model_runner.py +154 -54
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +140 -135
- sglang/srt/managers/detokenizer_manager.py +15 -19
- sglang/srt/managers/io_struct.py +10 -4
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/model_config.py +83 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +11 -4
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +33 -23
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +60 -19
- sglang/srt/server_args.py +79 -44
- sglang/srt/utils.py +146 -37
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
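A listing like the one above can be reproduced locally from the two wheels using only the Python standard library; the sketch below is illustrative (the `diff_wheels` helper and the wheel filenames are not part of sglang):

```python
# Hypothetical helper: print unified diffs between the Python sources of two
# locally downloaded wheel files (e.g. fetched with `pip download --no-deps`).
import difflib
import zipfile


def diff_wheels(old_whl: str, new_whl: str) -> None:
    with zipfile.ZipFile(old_whl) as old, zipfile.ZipFile(new_whl) as new:
        old_names, new_names = set(old.namelist()), set(new.namelist())
        for name in sorted(old_names | new_names):
            if not name.endswith(".py"):
                continue  # skip metadata such as RECORD, WHEEL, METADATA
            old_src = old.read(name).decode() if name in old_names else ""
            new_src = new.read(name).decode() if name in new_names else ""
            if old_src == new_src:
                continue
            diff = difflib.unified_diff(
                old_src.splitlines(keepends=True),
                new_src.splitlines(keepends=True),
                fromfile=f"0.1.17/{name}",
                tofile=f"0.1.18/{name}",
            )
            print("".join(diff))


# diff_wheels("sglang-0.1.17-py3-none-any.whl", "sglang-0.1.18-py3-none-any.whl")
```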
sglang/srt/managers/controller/tp_worker.py
CHANGED
@@ -1,9 +1,11 @@
+"""A tensor parallel worker."""
+
 import asyncio
 import logging
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import List, Optional
 
 import rpyc
 import torch
@@ -13,23 +15,30 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.managers.controller.infer_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    Batch,
+    ForwardMode,
+    Req,
+)
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.managers.controller.radix_cache import RadixCache
+from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
-from sglang.srt.managers.controller.model_runner import ModelRunner
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-
+    start_rpyc_service_process,
+    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -88,16 +97,16 @@ class ModelTpServer:
             trust_remote_code=server_args.trust_remote_code,
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens =
-
-
-
-
-
-
+        self.max_prefill_tokens = (
+            4096
+            if server_args.max_prefill_tokens is None
+            else server_args.max_prefill_tokens
+        )
+        self.max_running_requests = (
+            self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None
+            else server_args.max_running_requests
         )
-        self.max_running_requests = (self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None else server_args.max_running_requests)
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
@@ -108,7 +117,7 @@ class ModelTpServer:
                 f"[gpu_id={self.gpu_id}] "
                 f"max_total_num_tokens={self.max_total_num_tokens}, "
                 f"max_prefill_tokens={self.max_prefill_tokens}, "
-                f"context_len={self.model_config.context_len}
+                f"context_len={self.model_config.context_len}"
             )
         if self.tp_rank == 0:
             logger.info(
@@ -242,7 +251,7 @@ class ModelTpServer:
                     self.running_batch = None
                     break
 
-                if self.out_pyobjs and self.running_batch.
+                if self.out_pyobjs and self.running_batch.has_stream():
                     break
             else:
                 # Check the available size
@@ -271,13 +280,14 @@ class ModelTpServer:
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.image_size = recv_req.image_size
-
-
-
-
-
-
-
+            (
+                req.origin_input_ids,
+                req.image_offset,
+            ) = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded,
+                req.pad_value,
+                req.pixel_values.shape,
+                req.image_size,
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
@@ -303,7 +313,7 @@ class ModelTpServer:
             )
         self.forward_queue.append(req)
 
-    def get_new_fill_batch(self):
+    def get_new_fill_batch(self) -> Optional[Batch]:
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_running_requests
@@ -312,10 +322,7 @@ class ModelTpServer:
 
        # Compute matched prefix length
        for req in self.forward_queue:
-
-                len(req.output_ids) == 0
-            ), "The output ids should be empty when prefilling"
-            req.input_ids = req.origin_input_ids + req.prev_output_ids
+            req.input_ids = req.origin_input_ids + req.output_ids
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
                prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -361,8 +368,9 @@ class ModelTpServer:
            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
-                and req.extend_input_len + new_batch_input_tokens
-
+                and (req.extend_input_len + new_batch_input_tokens
+                     <= self.max_prefill_tokens
+                     or len(can_run_list) == 0)
            ):
                delta = self.tree_cache.inc_lock_ref(req.last_node)
                available_size += delta
@@ -401,7 +409,7 @@ class ModelTpServer:
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
            logger.info(
-                f"[gpu_id={self.gpu_id}]
+                f"[gpu_id={self.gpu_id}] Prefill batch. "
                f"#new-seq: {len(can_run_list)}, "
                f"#new-token: {new_batch_input_tokens}, "
                f"#cached-token: {hit_tokens}, "
@@ -432,97 +440,91 @@ class ModelTpServer:
            self.model_config.vocab_size, self.int_token_logit_bias
        )
 
+        # Forward and sample the next tokens
        if batch.extend_num_tokens != 0:
-
-
-
-
-
-
-
-            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            if prefill_token_logprobs is not None:
-                prefill_token_logprobs = prefill_token_logprobs.tolist()
-                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
-
-            next_token_ids, _ = batch.sample(logits)
-
-            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-            if last_logprobs is not None:
-                last_token_logprobs = last_logprobs[
-                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            next_token_ids, _ = batch.sample(output.next_token_logits)
+
+            # Move logprobs to cpu
+            if output.next_token_logprobs is not None:
+                output.next_token_logprobs = output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                    next_token_ids,
                ].tolist()
+                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
 
            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
-        # Check finish
+        # Check finish conditions
        pt = 0
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids
+            req.output_ids.append(next_token_ids[i])
            req.check_finished()
 
            if req.return_logprob:
-
-
-
-                if req.prefill_token_logprobs is None:
-                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                    req.prefill_token_logprobs = list(
-                        zip(
-                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                            req.input_ids[-req.extend_input_len + 1 :],
-                        )
-                    )
-                    if req.logprob_start_len == 0:
-                        req.prefill_token_logprobs = [
-                            (None, req.input_ids[0])
-                        ] + req.prefill_token_logprobs
-
-                if req.last_update_decode_tokens != 0:
-                    req.decode_token_logprobs.extend(
-                        list(
-                            zip(
-                                prefill_token_logprobs[
-                                    pt
-                                    + req.extend_input_len
-                                    - req.last_update_decode_tokens : pt
-                                    + req.extend_input_len
-                                    - 1
-                                ],
-                                req.input_ids[-req.last_update_decode_tokens + 1 :],
-                            )
-                        )
-                    )
+                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+            pt += req.extend_input_len
 
-
-                    (last_token_logprobs[i], next_token_ids[i])
-                )
+        self.handle_finished_requests(batch)
 
-
-
-
-                if req.logprob_start_len == 0:
-                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+        if req.normalized_prompt_logprob is None:
+            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
-
-
-
+        if req.prefill_token_logprobs is None:
+            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+            req.prefill_token_logprobs = list(
+                zip(
+                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    req.input_ids[-req.extend_input_len + 1 :],
+                )
+            )
+            if req.logprob_start_len == 0:
+                req.prefill_token_logprobs = [
+                    (None, req.input_ids[0])
+                ] + req.prefill_token_logprobs
+
+        if req.last_update_decode_tokens != 0:
+            req.decode_token_logprobs.extend(
+                list(
+                    zip(
+                        output.prefill_token_logprobs[
+                            pt
+                            + req.extend_input_len
+                            - req.last_update_decode_tokens : pt
+                            + req.extend_input_len
+                            - 1
+                        ],
+                        req.input_ids[-req.last_update_decode_tokens + 1 :],
                    )
-
+                )
+            )
+
+        req.decode_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )
 
-
+        if req.top_logprobs_num > 0:
+            if req.prefill_top_logprobs is None:
+                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                if req.logprob_start_len == 0:
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
 
-
+            if req.last_update_decode_tokens != 0:
+                req.decode_top_logprobs.extend(
+                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                )
+            req.decode_top_logprobs.append(output.decode_top_logprobs[i])
 
    def cache_filled_batch(self, batch: Batch):
        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
                del_in_memory_pool=False,
@@ -531,7 +533,7 @@ class ModelTpServer:
            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
 
    def forward_decode_batch(self, batch: Batch):
-        #
+        # Check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
@@ -550,9 +552,8 @@ class ModelTpServer:
            )
 
        if not self.disable_regex_jump_forward:
-            #
+            # Check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-
            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return
@@ -561,23 +562,19 @@ class ModelTpServer:
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()
 
-        # Forward
-
-
-            _,
-            _,
-            decode_top_logprobs,
-            last_logprobs,
-        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.tolist()
+        # Forward and sample the next tokens
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids, _ = batch.sample(output.next_token_logits)
 
-        #
-        if
-
-            torch.arange(len(
+        # Move logprobs to cpu
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
            ].tolist()
 
+        next_token_ids = next_token_ids.tolist()
+
        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
@@ -585,17 +582,17 @@ class ModelTpServer:
            req.check_finished()
 
            if req.return_logprob:
-                req.decode_token_logprobs.append((
-
-
-                    req.decode_top_logprobs.append(decode_top_logprobs[i])
+                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                if req.top_logprobs_num > 0:
+                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])
 
        self.handle_finished_requests(batch)
 
    def handle_finished_requests(self, batch: Batch):
        output_rids = []
-
-
+        decoded_texts = []
+        surr_output_ids = []
+        read_output_ids = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
@@ -618,8 +615,10 @@ class ModelTpServer:
                )
            ):
                output_rids.append(req.rid)
-
-
+                decoded_texts.append(req.decoded_text)
+                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                surr_output_ids.append(surr_ids)
+                read_output_ids.append(read_ids)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
@@ -629,7 +628,7 @@ class ModelTpServer:
 
                meta_info = {
                    "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.
+                    "completion_tokens": len(req.output_ids),
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": str(req.finished_reason),
                }
@@ -655,8 +654,9 @@ class ModelTpServer:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
-
-
+                    decoded_texts,
+                    surr_output_ids,
+                    read_output_ids,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
@@ -671,7 +671,7 @@ class ModelTpServer:
        for i in finished_indices:
            req = batch.reqs[i]
            self.tree_cache.cache_req(
-                token_ids=tuple(req.
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
            )
@@ -758,12 +758,17 @@ class ModelTpClient:
        else:
            with ThreadPoolExecutor(self.tp_size) as executor:
                # Launch model processes
-
-
-
-
-
-
+                if server_args.nnodes == 1:
+                    self.procs = list(executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                    ))
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+
+                self.model_services = list(executor.map(
+                    lambda args: connect_rpyc_service(*args), addrs))
 
                # Init model
                def init_model(i):
@@ -775,7 +780,7 @@ class ModelTpClient:
                        model_overide_args,
                    )
 
-                self.model_servers = executor.map(init_model, range(self.tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))
 
                # Wrap functions
                def async_wrap(func_name):
@@ -788,4 +793,4 @@ class ModelTpClient:
 
            return _func
 
-        self.step = async_wrap("step")
+        self.step = async_wrap("step")
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,3 +1,5 @@
+"""DetokenizerManager is a process that detokenizes the token ids."""
+
 import asyncio
 import inspect
 
@@ -6,10 +8,10 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import get_exception_traceback, graceful_registry
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -38,30 +40,24 @@ class DetokenizerManager:
            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
            assert isinstance(recv_obj, BatchTokenIDOut)
 
-            output_tokens = recv_obj.output_tokens
-
            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
-
-
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
                skip_special_tokens=recv_obj.skip_special_tokens[0],
-                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
-                    0
-                ],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
            )
 
            # Trim stop str
            # TODO(lmzheng): handle the case where multiple stop strs are hit
-
-
-
-
-                )
-                if not isinstance(first_token, str):
-                    first_token = first_token.decode("utf-8", errors="ignore")
-                if first_token.startswith("▁"):
-                    output_strs[i] = " " + output_strs[i]
-
-                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)
 
                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
sglang/srt/managers/io_struct.py
CHANGED
@@ -1,9 +1,14 @@
+"""
+The definition of objects transfered between different
+processes (TokenizerManager, DetokenizerManager, Controller).
+"""
+
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.sampling_params import SamplingParams
 
 
 @dataclass
@@ -30,7 +35,6 @@ class GenerateReqInput:
     stream: bool = False
 
     def post_init(self):
-
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
@@ -106,13 +110,15 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
-
-
+    decoded_texts: List[str]
+    surr_output_ids: List[List[int]]
+    read_output_ids: List[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
 
+
 @dataclass
 class BatchStrOut:
     rids: List[str]
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -1,10 +1,12 @@
+"""TokenizerManager is a process that tokenizes the text."""
+
 import asyncio
 import concurrent.futures
 import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import
+from typing import Dict, List
 
 import numpy as np
 import transformers
@@ -22,11 +24,11 @@ from sglang.srt.hf_transformers_utils import (
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchStrOut,
+    BatchTokenIDOut,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.io_struct import BatchTokenIDOut
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -90,7 +92,7 @@ class TokenizerManager:
        )
 
        self.to_create_loop = True
-        self.rid_to_state: Dict[str, ReqState] = {}
+        self.rid_to_state: Dict[str, ReqState] = {}
 
    async def get_pixel_values(self, image_data):
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -283,7 +285,7 @@ class TokenizerManager:
        req = AbortReq(rid)
        self.send_to_router.send_pyobj(req)
 
-    def create_abort_task(self, obj):
+    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(3)
@@ -321,7 +323,6 @@ class TokenizerManager:
            state.finished = recv_obj.finished_reason[i] is not None
            state.event.set()
 
-
    def convert_logprob_style(
        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
    ):
@@ -333,15 +334,15 @@ class TokenizerManager:
                ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
            )
            if top_logprobs_num > 0:
-                ret["meta_info"][
-
-
-
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
                )
-                ret["meta_info"][
-
-
-
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
                )
        return ret
 