sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +13 -6
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +2 -4
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +40 -35
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +110 -74
- sglang/srt/managers/tokenizer_manager.py +24 -15
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +60 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +118 -141
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +6 -8
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +8 -43
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/{llama2.py → llama.py} +48 -26
- sglang/srt/models/llama_classification.py +14 -40
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +38 -16
- sglang/srt/models/llavavid.py +7 -8
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +67 -58
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +70 -0
- sglang/test/test_utils.py +89 -1
- sglang/utils.py +38 -4
- sglang/version.py +1 -1
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.2.15.dist-info/RECORD +0 -118
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -15,19 +15,21 @@ limitations under the License.

 """A tensor parallel worker."""

+import json
 import logging
 import multiprocessing
 import os
 import pickle
 import time
 import warnings
-from typing import Any, List, Optional
+from typing import Any, List, Optional

 import torch
 import torch.distributed
 import torch.distributed as dist

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -51,8 +53,6 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
-from sglang.srt.model_config import ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -66,6 +66,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)


+# Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"


@@ -76,26 +77,26 @@ class ModelTpServer:
 tp_rank: int,
 server_args: ServerArgs,
 nccl_port: int,
-model_override_args: dict,
 ):
 suppress_other_loggers()

-#
+# Parse arguments
 self.gpu_id = gpu_id
 self.tp_rank = tp_rank
 self.tp_size = server_args.tp_size
 self.dp_size = server_args.dp_size
 self.schedule_policy = server_args.schedule_policy
 self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+self.lora_paths = server_args.lora_paths
+self.max_loras_per_batch = server_args.max_loras_per_batch

 # Init model and tokenizer
 self.model_config = ModelConfig(
 server_args.model_path,
 server_args.trust_remote_code,
 context_length=server_args.context_length,
-model_override_args=
+model_override_args=json.loads(server_args.json_model_override_args),
 )
-
 self.model_runner = ModelRunner(
 model_config=self.model_config,
 mem_fraction_static=server_args.mem_fraction_static,
@@ -129,14 +130,14 @@ class ModelTpServer:
 if server_args.max_running_requests is None
 else server_args.max_running_requests
 ),
-self.model_runner.req_to_token_pool.size
+self.model_runner.req_to_token_pool.size,
 )
 self.max_req_input_len = min(
 self.model_config.context_len - 1,
 self.max_total_num_tokens - 1,
 )

-# Sync random seed
+# Sync random seed across TP workers
 server_args.random_seed = broadcast_recv_input(
 [server_args.random_seed],
 self.tp_rank,
@@ -144,7 +145,7 @@ class ModelTpServer:
 )[0]
 set_random_seed(server_args.random_seed)

-# Print info
+# Print debug info
 logger.info(
 f"max_total_num_tokens={self.max_total_num_tokens}, "
 f"max_prefill_tokens={self.max_prefill_tokens}, "
@@ -181,7 +182,7 @@ class ModelTpServer:
 self.num_generated_tokens = 0
 self.last_stats_tic = time.time()

-#
+# Init chunked prefill
 self.chunked_prefill_size = server_args.chunked_prefill_size
 self.current_inflight_req = None
 self.is_mixed_chunk = (
@@ -197,16 +198,6 @@ class ModelTpServer:
 "trust_remote_code": server_args.trust_remote_code,
 },
 skip_tokenizer_init=server_args.skip_tokenizer_init,
-json_schema_mode=False,
-)
-self.json_fsm_cache = FSMCache(
-server_args.tokenizer_path,
-{
-"tokenizer_mode": server_args.tokenizer_mode,
-"trust_remote_code": server_args.trust_remote_code,
-},
-skip_tokenizer_init=server_args.skip_tokenizer_init,
-json_schema_mode=True,
 )
 self.jump_forward_cache = JumpForwardCache()

@@ -221,15 +212,18 @@ class ModelTpServer:
 )
 self.new_token_ratio = self.min_new_token_ratio
 self.new_token_ratio_decay = global_config.new_token_ratio_decay
+self.do_not_get_new_batch = False

 def exposed_step(self, recv_reqs: List):
 try:
 # Recv requests
 for recv_req in recv_reqs:
-if isinstance(
-recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
-):
+if isinstance(recv_req, TokenizedGenerateReqInput):
 self.handle_generate_request(recv_req)
+self.do_not_get_new_batch = False
+elif isinstance(recv_req, TokenizedEmbeddingReqInput):
+self.handle_embedding_request(recv_req)
+self.do_not_get_new_batch = False
 elif isinstance(recv_req, FlushCacheReq):
 self.flush_cache()
 elif isinstance(recv_req, AbortReq):
@@ -253,7 +247,11 @@ class ModelTpServer:

 @torch.inference_mode()
 def forward_step(self):
-
+if self.do_not_get_new_batch and self.current_inflight_req is None:
+new_batch = None
+else:
+new_batch = self.get_new_prefill_batch()
+self.do_not_get_new_batch = False

 if new_batch is not None:
 # Run a new prefill batch
@@ -280,7 +278,7 @@ class ModelTpServer:
 self.running_batch = None
 break

-if self.out_pyobjs and self.running_batch.has_stream
+if self.out_pyobjs and self.running_batch.has_stream:
 break
 else:
 self.check_memory()
@@ -325,73 +323,102 @@ class ModelTpServer:

 def handle_generate_request(
 self,
-recv_req:
+recv_req: TokenizedGenerateReqInput,
 ):
-
+if isinstance(recv_req, TokenizedGenerateReqInput):
+req = Req(
+recv_req.rid,
+recv_req.input_text,
+recv_req.input_ids,
+lora_path=recv_req.lora_path,
+)
+else:
+req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
 req.tokenizer = self.tokenizer
 req.sampling_params = recv_req.sampling_params
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-req.
-
-
-
-
-
+req.pixel_values = recv_req.pixel_values
+if req.pixel_values is not None:
+# Use image hash as fake token_ids, which is then used
+# for prefix matching
+image_hash = hash(tuple(recv_req.image_hashes))
+req.pad_value = [
+(image_hash) % self.model_config.vocab_size,
+(image_hash >> 16) % self.model_config.vocab_size,
+(image_hash >> 32) % self.model_config.vocab_size,
+(image_hash >> 64) % self.model_config.vocab_size,
+]
+req.image_sizes = recv_req.image_sizes
+(
+req.origin_input_ids,
+req.image_offsets,
+) = self.model_runner.model.pad_input_ids(
+req.origin_input_ids_unpadded,
+req.pad_value,
+req.pixel_values,
+req.image_sizes,
+)
+# Only when pixel values is not None we have modalities
+req.modalities = recv_req.modalites
+req.return_logprob = recv_req.return_logprob
+req.top_logprobs_num = recv_req.top_logprobs_num
+req.stream = recv_req.stream
+req.logprob_start_len = recv_req.logprob_start_len
+
+if req.logprob_start_len == -1:
+# By default, only return the logprobs for output tokens
+req.logprob_start_len = len(recv_req.input_ids) - 1
+
+# Init regex FSM
+if (
+req.sampling_params.json_schema is not None
+or req.sampling_params.regex is not None
+):
 if req.sampling_params.json_schema is not None:
-req.regex_fsm, computed_regex_string = self.
-req.sampling_params.json_schema
+req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+("json", req.sampling_params.json_schema)
 )
-if not self.disable_regex_jump_forward:
-req.jump_forward_map = self.jump_forward_cache.query(
-computed_regex_string
-)
-
-# Init regex fsm
 elif req.sampling_params.regex is not None:
-req.regex_fsm = self.regex_fsm_cache.query(
-
-
-
-
+req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+("regex", req.sampling_params.regex)
+)
+if not self.disable_regex_jump_forward:
+req.jump_forward_map = self.jump_forward_cache.query(
+computed_regex_string
+)

 # Truncate prompts that are too long
 if len(req.origin_input_ids) >= self.max_req_input_len:
-logger.
+logger.warning(
 "Request length is longer than the KV cache pool size or "
 "the max context length. Truncated!!!"
 )
 req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+req.sampling_params.max_new_tokens = min(
+(
+req.sampling_params.max_new_tokens
+if req.sampling_params.max_new_tokens is not None
+else 1 << 30
+),
+self.max_req_input_len - 1 - len(req.origin_input_ids),
+)

-
-
-
-
-
-
-
-
+self.waiting_queue.append(req)
+
+def handle_embedding_request(
+self,
+recv_req: TokenizedEmbeddingReqInput,
+):
+req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+req.tokenizer = self.tokenizer
+req.sampling_params = recv_req.sampling_params
+
+# Truncate prompts that are too long
+if len(req.origin_input_ids) >= self.max_req_input_len:
+logger.warn(
+"Request length is longer than the KV cache pool size or "
+"the max context length. Truncated!!!"
 )
+req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

 self.waiting_queue.append(req)

@@ -409,6 +436,8 @@ class ModelTpServer:

 adder = PrefillAdder(
 self.tree_cache,
+self.running_batch,
+self.new_token_ratio,
 self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
 self.max_prefill_tokens,
 self.chunked_prefill_size,
@@ -416,7 +445,7 @@ class ModelTpServer:
 )

 if self.running_batch is not None:
-adder.remove_running_tokens(self.running_batch
+adder.remove_running_tokens(self.running_batch)

 has_inflight = self.current_inflight_req is not None
 if self.current_inflight_req is not None:
@@ -427,12 +456,30 @@ class ModelTpServer:
 self.current_inflight_req
 )

+if self.lora_paths is not None:
+lora_set = (
+set([req.lora_path for req in self.running_batch.reqs])
+if self.running_batch is not None
+else set([])
+)
+
 for req in self.waiting_queue:
+if adder.no_remaining_tokens():
+break
 req.init_next_round_input(None if prefix_computed else self.tree_cache)
+if (
+self.lora_paths is not None
+and len(
+lora_set
+| set([req.lora_path for req in adder.can_run_list])
+| set([req.lora_path])
+)
+> self.max_loras_per_batch
+):
+break
 res = adder.add_one_req(req)
 if (
 not res
-or adder.no_remaining_tokens()
 or running_bs + len(adder.can_run_list) >= self.max_running_requests
 ):
 break
@@ -504,10 +551,9 @@ class ModelTpServer:
 if self.model_runner.is_generation:
 # Forward and sample the next tokens
 if batch.extend_num_tokens != 0:
-
-
-
-next_token_ids = batch.check_sample_results(sample_output)
+logits_output = self.model_runner.forward(batch)
+next_token_ids = self.model_runner.sample(logits_output, batch)
+
 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
 next_token_ids
 )
@@ -541,7 +587,7 @@ class ModelTpServer:
 next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

 # Check finish conditions
-
+logprob_pt = 0
 for i, req in enumerate(batch.reqs):
 if req is not self.current_inflight_req:
 # Inflight reqs' prefill is not finished
@@ -565,13 +611,12 @@ class ModelTpServer:
 self.req_to_token_pool.free(req.req_pool_idx)

 if req.return_logprob:
-self.add_logprob_return_values(
-i, req,
+logprob_pt += self.add_logprob_return_values(
+i, req, logprob_pt, next_token_ids, logits_output
 )
-pt += req.extend_input_len
 else:
 assert batch.extend_num_tokens != 0
-logits_output = self.model_runner.forward(batch
+logits_output = self.model_runner.forward(batch)
 embeddings = logits_output.embeddings.tolist()

 # Check finish conditions
@@ -596,48 +641,63 @@ class ModelTpServer:

 def add_logprob_return_values(
 self,
-i,
+i: int,
 req: Req,
 pt: int,
 next_token_ids: List[int],
 output: LogitsProcessorOutput,
 ):
+"""Attach logprobs to the return values."""
+req.output_token_logprobs.append(
+(output.next_token_logprobs[i], next_token_ids[i])
+)
+
+# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
+
 if req.normalized_prompt_logprob is None:
 req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

 if req.input_token_logprobs is None:
-
-
-
-
-
-
-
-
+input_token_logprobs = output.input_token_logprobs[
+pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
+]
+input_token_ids = req.fill_ids[
+len(req.fill_ids)
+- num_input_logprobs
++ 1 : len(req.fill_ids)
+- req.last_update_decode_tokens
+]
+req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
+
+if (
+req.logprob_start_len == 0
+): # The first token does not have logprob, pad it.
 req.input_token_logprobs = [
 (None, req.fill_ids[0])
 ] + req.input_token_logprobs

 if req.last_update_decode_tokens != 0:
+# Some decode tokens are re-computed in an extend batch
 req.output_token_logprobs.extend(
 list(
 zip(
 output.input_token_logprobs[
 pt
-+
++ num_input_logprobs
+- 1
 - req.last_update_decode_tokens : pt
-+
++ num_input_logprobs
 - 1
 ],
-req.fill_ids[
+req.fill_ids[
+len(req.fill_ids)
+- req.last_update_decode_tokens : len(req.fill_ids)
+],
 )
 )
 )

-req.output_token_logprobs.append(
-(output.next_token_logprobs[i], next_token_ids[i])
-)
-
 if req.top_logprobs_num > 0:
 if req.input_top_logprobs is None:
 req.input_top_logprobs = output.input_top_logprobs[i]
@@ -646,10 +706,12 @@ class ModelTpServer:

 if req.last_update_decode_tokens != 0:
 req.output_top_logprobs.extend(
-output.input_top_logprobs[i][-req.last_update_decode_tokens
+output.input_top_logprobs[i][-req.last_update_decode_tokens :]
 )
 req.output_top_logprobs.append(output.output_top_logprobs[i])

+return num_input_logprobs
+
 def forward_decode_batch(self, batch: ScheduleBatch):
 # Check if decode out of memory
 if not batch.check_decode_mem():
@@ -682,10 +744,8 @@ class ModelTpServer:
 batch.prepare_for_decode()

 # Forward and sample the next tokens
-
-
-)
-next_token_ids = batch.check_sample_results(sample_output)
+logits_output = self.model_runner.forward(batch)
+next_token_ids = self.model_runner.sample(logits_output, batch)
 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
 next_token_ids
 )
@@ -700,6 +760,7 @@ class ModelTpServer:
 next_token_ids = next_token_ids.tolist()

 # Check finish condition
+has_finished = False
 for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
 req.completion_tokens_wo_jump_forward += 1
 req.output_ids.append(next_token_id)
@@ -712,6 +773,7 @@ class ModelTpServer:

 if req.finished():
 self.tree_cache.cache_finished_req(req)
+has_finished = True

 if req.return_logprob:
 req.output_token_logprobs.append(
@@ -720,6 +782,9 @@ class ModelTpServer:
 if req.top_logprobs_num > 0:
 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

+if not has_finished:
+self.do_not_get_new_batch = True
+
 self.handle_finished_requests(batch)

 def handle_finished_requests(self, batch: ScheduleBatch):
@@ -769,7 +834,11 @@ class ModelTpServer:
 "prompt_tokens": len(req.origin_input_ids),
 "completion_tokens": len(req.output_ids),
 "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-"finish_reason":
+"finish_reason": (
+req.finished_reason.to_json()
+if req.finished_reason is not None
+else None
+),
 }
 if req.return_logprob:
 (
@@ -876,7 +945,6 @@ def run_tp_server(
 tp_rank: int,
 server_args: ServerArgs,
 nccl_port: int,
-model_override_args: dict,
 ):
 """Run a tensor parallel model server."""
 configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -887,7 +955,6 @@ def run_tp_server(
 tp_rank,
 server_args,
 nccl_port,
-model_override_args,
 )
 tp_cpu_group = model_server.model_runner.tp_group.cpu_group

@@ -904,14 +971,13 @@ def launch_tp_servers(
 tp_rank_range: List[int],
 server_args: ServerArgs,
 nccl_port: int,
-model_override_args: dict,
 ):
 """Launch multiple tensor parallel servers."""
 procs = []
 for i in tp_rank_range:
 proc = multiprocessing.Process(
 target=run_tp_server,
-args=(gpu_ids[i], i, server_args, nccl_port
+args=(gpu_ids[i], i, server_args, nccl_port),
 )
 proc.start()
 procs.append(proc)