sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +30 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +317 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +41 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -2
- sglang/lang/ir.py +74 -28
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +68 -9
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +280 -169
- sglang/srt/layers/logits_processor.py +106 -42
- sglang/srt/layers/radix_attention.py +53 -29
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +144 -69
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +9 -4
- sglang/srt/managers/controller/model_runner.py +167 -55
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +156 -134
- sglang/srt/managers/detokenizer_manager.py +19 -21
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/tokenizer_manager.py +16 -14
- sglang/srt/model_config.py +89 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +12 -5
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +35 -25
- sglang/srt/openai_protocol.py +2 -2
- sglang/srt/server.py +69 -19
- sglang/srt/server_args.py +76 -43
- sglang/srt/utils.py +177 -35
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
- sglang-0.1.19.dist-info/RECORD +81 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/tp_worker.py
CHANGED
@@ -1,9 +1,11 @@
+"""A tensor parallel worker."""
+
 import asyncio
 import logging
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import List, Optional

 import rpyc
 import torch
@@ -13,23 +15,30 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.managers.controller.infer_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    Batch,
+    ForwardMode,
+    Req,
+)
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.managers.controller.radix_cache import RadixCache
+from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
-from sglang.srt.managers.controller.model_runner import ModelRunner
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (
+    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-
+    start_rpyc_service_process,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -88,16 +97,16 @@ class ModelTpServer:
             trust_remote_code=server_args.trust_remote_code,
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens =
-
-
-
-
-
-
+        self.max_prefill_tokens = (
+            4096
+            if server_args.max_prefill_tokens is None
+            else server_args.max_prefill_tokens
+        )
+        self.max_running_requests = (
+            self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None
+            else server_args.max_running_requests
         )
-        self.max_running_requests = (self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None else server_args.max_running_requests)
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
@@ -108,7 +117,7 @@ class ModelTpServer:
             f"[gpu_id={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"context_len={self.model_config.context_len}
+            f"context_len={self.model_config.context_len}"
         )
         if self.tp_rank == 0:
             logger.info(
@@ -242,7 +251,7 @@ class ModelTpServer:
                         self.running_batch = None
                         break

-                    if self.out_pyobjs and self.running_batch.
+                    if self.out_pyobjs and self.running_batch.has_stream():
                         break
             else:
                 # Check the available size
@@ -271,13 +280,14 @@ class ModelTpServer:
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.image_size = recv_req.image_size
-
-
-
-
-
-
-
+            (
+                req.origin_input_ids,
+                req.image_offset,
+            ) = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded,
+                req.pad_value,
+                req.pixel_values.shape,
+                req.image_size,
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
@@ -303,7 +313,7 @@ class ModelTpServer:
         )
         self.forward_queue.append(req)

-    def get_new_fill_batch(self):
+    def get_new_fill_batch(self) -> Optional[Batch]:
         if (
             self.running_batch is not None
             and len(self.running_batch.reqs) > self.max_running_requests
@@ -312,10 +322,7 @@ class ModelTpServer:

         # Compute matched prefix length
         for req in self.forward_queue:
-
-                len(req.output_ids) == 0
-            ), "The output ids should be empty when prefilling"
-            req.input_ids = req.origin_input_ids + req.prev_output_ids
+            req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -361,8 +368,11 @@ class ModelTpServer:
             if (
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and
-
+                and (
+                    req.extend_input_len + new_batch_input_tokens
+                    <= self.max_prefill_tokens
+                    or len(can_run_list) == 0
+                )
             ):
                 delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
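
The admission test above gates each queued request on two budgets: its projected KV usage (extend length plus decode budget plus what the batch already holds) must fit in the available pool, and the batch's total extend length must stay under max_prefill_tokens unless the batch is still empty. A standalone, simplified sketch of that check with hypothetical request fields:

# Hypothetical, simplified version of the prefill admission loop above; the
# real loop also handles prefix-cache lock refs and other bookkeeping.
from dataclasses import dataclass
from typing import List


@dataclass
class Candidate:
    extend_input_len: int   # uncached prompt tokens to prefill
    max_new_tokens: int     # decode budget reserved for the request


def build_prefill_batch(
    queue: List[Candidate], available_size: int, max_prefill_tokens: int
) -> List[Candidate]:
    can_run: List[Candidate] = []
    new_batch_total_tokens = 0
    new_batch_input_tokens = 0
    for req in queue:
        fits_kv = (
            req.extend_input_len + req.max_new_tokens + new_batch_total_tokens
            < available_size
        )
        fits_prefill = (
            req.extend_input_len + new_batch_input_tokens <= max_prefill_tokens
            or len(can_run) == 0
        )
        if fits_kv and fits_prefill:
            can_run.append(req)
            new_batch_total_tokens += req.extend_input_len + req.max_new_tokens
            new_batch_input_tokens += req.extend_input_len
    return can_run


# The first 1200-token prompt is admitted despite a 1024-token prefill budget
# because the batch is still empty; the second request is then rejected, so
# the batch holds exactly one request.
print(len(build_prefill_batch([Candidate(1200, 64), Candidate(512, 64)], 8192, 1024)))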
@@ -401,7 +411,7 @@ class ModelTpServer:
                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
             )
             logger.info(
-                f"[gpu_id={self.gpu_id}]
+                f"[gpu_id={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
                 f"#new-token: {new_batch_input_tokens}, "
                 f"#cached-token: {hit_tokens}, "
@@ -432,97 +442,93 @@ class ModelTpServer:
             self.model_config.vocab_size, self.int_token_logit_bias
         )

+        # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
-
-
-
-
-
-
-
-            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            if prefill_token_logprobs is not None:
-                prefill_token_logprobs = prefill_token_logprobs.tolist()
-                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
-
-            next_token_ids, _ = batch.sample(logits)
-
-            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-            if last_logprobs is not None:
-                last_token_logprobs = last_logprobs[
-                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            next_token_ids, _ = batch.sample(output.next_token_logits)
+
+            # Move logprobs to cpu
+            if output.next_token_logprobs is not None:
+                output.next_token_logprobs = output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                     next_token_ids,
                 ].tolist()
+                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                output.normalized_prompt_logprobs = (
+                    output.normalized_prompt_logprobs.tolist()
+                )

             next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

-        # Check finish
+        # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids
+            req.output_ids.append(next_token_ids[i])
             req.check_finished()

             if req.return_logprob:
-
-
-
-                if req.prefill_token_logprobs is None:
-                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                    req.prefill_token_logprobs = list(
-                        zip(
-                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                            req.input_ids[-req.extend_input_len + 1 :],
-                        )
-                    )
-                    if req.logprob_start_len == 0:
-                        req.prefill_token_logprobs = [
-                            (None, req.input_ids[0])
-                        ] + req.prefill_token_logprobs
-
-                if req.last_update_decode_tokens != 0:
-                    req.decode_token_logprobs.extend(
-                        list(
-                            zip(
-                                prefill_token_logprobs[
-                                    pt
-                                    + req.extend_input_len
-                                    - req.last_update_decode_tokens : pt
-                                    + req.extend_input_len
-                                    - 1
-                                ],
-                                req.input_ids[-req.last_update_decode_tokens + 1 :],
-                            )
-                        )
-                    )
+                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+            pt += req.extend_input_len

-
-                    (last_token_logprobs[i], next_token_ids[i])
-                )
+        self.handle_finished_requests(batch)

-
-
-
-                if req.logprob_start_len == 0:
-                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+        if req.normalized_prompt_logprob is None:
+            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-
-
-
+        if req.prefill_token_logprobs is None:
+            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+            req.prefill_token_logprobs = list(
+                zip(
+                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    req.input_ids[-req.extend_input_len + 1 :],
+                )
+            )
+            if req.logprob_start_len == 0:
+                req.prefill_token_logprobs = [
+                    (None, req.input_ids[0])
+                ] + req.prefill_token_logprobs
+
+        if req.last_update_decode_tokens != 0:
+            req.decode_token_logprobs.extend(
+                list(
+                    zip(
+                        output.prefill_token_logprobs[
+                            pt
+                            + req.extend_input_len
+                            - req.last_update_decode_tokens : pt
+                            + req.extend_input_len
+                            - 1
+                        ],
+                        req.input_ids[-req.last_update_decode_tokens + 1 :],
                     )
-
+                )
+            )

-
+        req.decode_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )

-
+        if req.top_logprobs_num > 0:
+            if req.prefill_top_logprobs is None:
+                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                if req.logprob_start_len == 0:
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+            if req.last_update_decode_tokens != 0:
+                req.decode_top_logprobs.extend(
+                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                )
+            req.decode_top_logprobs.append(output.decode_top_logprobs[i])

     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
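
Both forward paths now return a single output object, and only the logprobs of the tokens that were actually sampled are copied to the CPU, via a row/column gather over the [batch, vocab] logprob tensor. A small sketch of that gather with plain torch tensors (the output object itself is sglang-internal and not reproduced here):

# Sketch of the "move logprobs to cpu" step above, using plain tensors.
# logprobs: [batch, vocab] log-probabilities from the forward pass;
# next_token_ids: [batch] token ids chosen by sampling.
import torch

batch_size, vocab_size = 4, 32000
logprobs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=-1)
next_token_ids = torch.argmax(logprobs, dim=-1)

# Advanced indexing picks logprobs[i, next_token_ids[i]] for every row, so only
# batch_size scalars (not the whole vocab) are transferred off the device.
selected = logprobs[
    torch.arange(batch_size, device=next_token_ids.device),
    next_token_ids,
].tolist()

print(len(selected), selected[0])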
@@ -531,7 +537,7 @@ class ModelTpServer:
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

     def forward_decode_batch(self, batch: Batch):
-        #
+        # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
             self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
@@ -550,9 +556,8 @@ class ModelTpServer:
             )

         if not self.disable_regex_jump_forward:
-            #
+            # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-
             self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
@@ -561,23 +566,19 @@ class ModelTpServer:
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()

-        # Forward
-
-
-            _,
-            _,
-            decode_top_logprobs,
-            last_logprobs,
-        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.tolist()
+        # Forward and sample the next tokens
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids, _ = batch.sample(output.next_token_logits)

-        #
-        if
-
-                torch.arange(len(
+        # Move logprobs to cpu
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
            ].tolist()

+        next_token_ids = next_token_ids.tolist()
+
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
@@ -585,17 +586,19 @@ class ModelTpServer:
             req.check_finished()

             if req.return_logprob:
-                req.decode_token_logprobs.append(
-
-
-                    req.
+                req.decode_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
+                if req.top_logprobs_num > 0:
+                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])

         self.handle_finished_requests(batch)

     def handle_finished_requests(self, batch: Batch):
         output_rids = []
-
-
+        decoded_texts = []
+        surr_output_ids = []
+        read_output_ids = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -618,8 +621,10 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-
-
+                decoded_texts.append(req.decoded_text)
+                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                surr_output_ids.append(surr_ids)
+                read_output_ids.append(read_ids)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -629,7 +634,7 @@ class ModelTpServer:

                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.
+                    "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                     "finish_reason": str(req.finished_reason),
                 }
@@ -655,8 +660,9 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
-
-
+                    decoded_texts,
+                    surr_output_ids,
+                    read_output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -671,7 +677,7 @@ class ModelTpServer:
         for i in finished_indices:
             req = batch.reqs[i]
             self.tree_cache.cache_req(
-                token_ids=tuple(req.
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
             )
@@ -758,12 +764,28 @@ class ModelTpClient:
         else:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
-
-
-
+                if server_args.nnodes == 1:
+                    self.procs = list(
+                        executor.map(
+                            lambda args: start_rpyc_service_process(*args),
+                            [
+                                (ModelTpService, p)
+                                for p in model_port_args.model_tp_ports
+                            ],
+                        )
+                    )
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [
+                        (ip, port)
+                        for ip, port in zip(
+                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                        )
+                    ]
+
+                self.model_services = list(
+                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
                 )
-                self.model_services = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]

                 # Init model
                 def init_model(i):
@@ -775,7 +797,7 @@ class ModelTpClient:
                         model_overide_args,
                     )

-                self.model_servers = executor.map(init_model, range(self.tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))

         # Wrap functions
         def async_wrap(func_name):
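
Wrapping executor.map(...) in list() matters here: map returns a one-shot lazy iterator, so results (and any exception raised inside init_model) only surface when it is consumed, and it cannot be indexed or iterated twice. Materializing it inside the with block yields a reusable list. A tiny illustration:

# Demonstrates why the diff wraps executor.map(...) in list(): map() returns a
# one-shot lazy iterator whose results and exceptions appear only on consumption.
from concurrent.futures import ThreadPoolExecutor


def init_model(i: int) -> str:
    return f"model-{i}"


with ThreadPoolExecutor(4) as executor:
    lazy = executor.map(init_model, range(4))         # iterator, not yet consumed
    eager = list(executor.map(init_model, range(4)))  # results collected right away

print(eager[2])    # indexable and reusable: "model-2"
print(list(lazy))  # consuming the iterator works once...
print(list(lazy))  # ...a second pass yields [] because it is exhausted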
@@ -788,4 +810,4 @@ class ModelTpClient:

             return _func

-        self.step = async_wrap("step")
+        self.step = async_wrap("step")
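
Only the tail of async_wrap is visible in this hunk. As a generic illustration (not the sglang implementation), a blocking client method such as an rpyc remote call can be wrapped into an awaitable like this:

# Generic sketch of wrapping a blocking call into an async function; the
# BlockingClient class is a stand-in, not part of sglang.
import asyncio
import time


class BlockingClient:
    def step(self, x: int) -> int:
        time.sleep(0.1)  # stands in for a blocking remote call
        return x + 1


def async_wrap(client: BlockingClient, func_name: str):
    func = getattr(client, func_name)

    async def _func(*args, **kwargs):
        # Run the blocking call in a worker thread so the event loop stays free.
        return await asyncio.to_thread(func, *args, **kwargs)

    return _func


async def main():
    step = async_wrap(BlockingClient(), "step")
    print(await step(41))  # 42


asyncio.run(main())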
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,3 +1,5 @@
+"""DetokenizerManager is a process that detokenizes the token ids."""
+
 import asyncio
 import inspect

@@ -6,10 +8,10 @@ import zmq
 import zmq.asyncio

 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback, graceful_registry
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
+from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -38,30 +40,26 @@ class DetokenizerManager:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)

-            output_tokens = recv_obj.output_tokens
-
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
-
-
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
-                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
-                    0
-                ],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )

             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
-
-
-
-
-            )
-
-                first_token = first_token.decode("utf-8", errors="ignore")
-                if first_token.startswith("▁"):
-                    output_strs[i] = " " + output_strs[i]
-
-                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                if recv_obj.finished_reason[i] is None:
+                    new_text = find_printable_text(new_text)
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)

                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
@@ -71,7 +69,7 @@ class DetokenizerManager:
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
                     rids=recv_obj.rids,
-
+                    output_strs=output_strs,
                     meta_info=recv_obj.meta_info,
                     finished_reason=recv_obj.finished_reason,
                 )
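
The new detokenizer path decodes two nested id windows per request, a short "surrounding" window and a "read" window that extends it with the newly generated ids, and keeps only the suffix of the longer decode, so tokenizer artifacts at the window boundary cancel out. A minimal sketch of that idea with a Hugging Face tokenizer ("gpt2" is just an example checkpoint, and the window bounds are arbitrary):

# Minimal sketch of the surr/read incremental detokenization used above.
# Assumes a Hugging Face tokenizer; "gpt2" is only an example checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

ids = tokenizer.encode("Hello world, incremental decoding is tricky")
decoded_text = tokenizer.decode(ids[:4])  # text already sent to the client
surr_ids = ids[2:4]                       # small window of recent ids (context)
read_ids = ids[2:6]                       # same window plus two newly generated ids

surr_text = tokenizer.decode(surr_ids)
read_text = tokenizer.decode(read_ids)
new_text = read_text[len(surr_text):]     # suffix = text contributed by the new ids
print(decoded_text + new_text)            # matches tokenizer.decode(ids[:6]) for this input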
sglang/srt/managers/io_struct.py
CHANGED
@@ -1,9 +1,14 @@
+"""
+The definition of objects transfered between different
+processes (TokenizerManager, DetokenizerManager, Controller).
+"""
+
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

-from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.sampling_params import SamplingParams


 @dataclass
@@ -30,7 +35,6 @@ class GenerateReqInput:
     stream: bool = False

     def post_init(self):
-
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
@@ -106,17 +110,19 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
-
-
+    decoded_texts: List[str]
+    surr_output_ids: List[List[int]]
+    read_output_ids: List[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]

+
 @dataclass
 class BatchStrOut:
     rids: List[str]
-
+    output_strs: List[str]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]

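
BatchTokenIDOut and BatchStrOut are plain dataclasses that travel between the controller, detokenizer, and tokenizer manager as pickled ZeroMQ messages (send_pyobj/recv_pyobj, as seen in the detokenizer diff above). A self-contained sketch of that transport with a stand-in dataclass rather than the real sglang classes:

# Sketch of how dataclasses like BatchStrOut move between processes: pyzmq's
# send_pyobj/recv_pyobj pickle and unpickle the object. `Result` is a stand-in
# for the real sglang dataclasses, not part of the package.
from dataclasses import dataclass, field
from typing import Dict, List

import zmq


@dataclass
class Result:
    rids: List[str]
    output_strs: List[str]
    meta_info: List[Dict] = field(default_factory=list)


ctx = zmq.Context.instance()
sender = ctx.socket(zmq.PUSH)
sender.bind("inproc://detokenizer")
receiver = ctx.socket(zmq.PULL)
receiver.connect("inproc://detokenizer")

sender.send_pyobj(Result(rids=["req-0"], output_strs=["Hello"]))
print(receiver.recv_pyobj())  # Result(rids=['req-0'], output_strs=['Hello'], meta_info=[])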